Commit 1606ac08 authored by Christina Floristean

Merge branch 'main' into multimer

parents 58d65692 67a00a6c
...@@ -39,13 +39,14 @@ kernels support in-place attention during inference and training. They use
implementations, respectively.
- **Efficient alignment scripts** using the original AlphaFold HHblits/JackHMMER pipeline or [ColabFold](https://github.com/sokrypton/ColabFold)'s, which uses the faster MMseqs2 instead. We've used them to generate millions of alignments.
- **FlashAttention** support greatly speeds up MSA attention.
- **DeepSpeed DS4Sci_EvoformerAttention kernel** is a memory-efficient attention kernel developed as part of a collaboration between OpenFold and the DeepSpeed4Science initiative. The kernel provides substantial speedups for training and inference and reduces the model's peak device memory requirement by 13x. The model is 15% faster during the initial training and fine-tuning stages, and up to 4x faster during inference. To use this feature, set the `use_deepspeed_evo_attention` option in `openfold/config.py`.
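Since only one attention backend may be active at a time, a quick sanity check along these lines can catch conflicting settings before a run (a minimal sketch using a stand-in config object, not OpenFold's actual `mlc.ConfigDict`):

```python
from types import SimpleNamespace

# Stand-in for the attention flags in the "globals" section of openfold/config.py
globals_cfg = SimpleNamespace(
    use_deepspeed_evo_attention=True,  # DS4Sci_EvoformerAttention kernel
    use_lma=False,                     # Staats & Rabe low-memory attention
    use_flash=False,                   # FlashAttention
)

enabled = [
    globals_cfg.use_deepspeed_evo_attention,
    globals_cfg.use_lma,
    globals_cfg.use_flash,
]
if sum(enabled) > 1:
    raise ValueError("Only one attention backend may be enabled at a time")
```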
## Installation (Linux)
All Python dependencies are specified in `environment.yml`. For producing sequence
alignments, you'll also need `kalign`, the [HH-suite](https://github.com/soedinglab/hh-suite),
and one of {`jackhmmer`, [MMseqs2](https://github.com/soedinglab/mmseqs2) (nightly build)}
installed on your system. You'll need `git-lfs` to download OpenFold parameters.
Finally, some download scripts require `aria2c` and `aws`.

This package is currently supported for CUDA 11 and PyTorch 1.12.
...@@ -114,7 +115,13 @@ the model, consult `run_pretrained_openfold.py`.
### Inference
To run inference on a sequence or multiple sequences using a set of DeepMind's
pretrained parameters, first download the OpenFold weights, e.g.:
```bash
bash scripts/download_openfold_params.sh openfold/resources
```
then run e.g.:
```bash
python3 run_pretrained_openfold.py \
...@@ -286,6 +293,14 @@ python scripts/precompute_embeddings.py fasta_dir/ embeddings_output_dir/
In the same per-label subdirectories inside `embeddings_output_dir`, you can also place `*.hhr` files (outputs from HHSearch), which can contain the details about the structures that you want to use as templates. If you do not place any such file, templates will not be used and only the ESM-1b embeddings will be used to predict the structure. If you want to use templates, you need to pass the PDB MMCIF dataset to the command.
Then download the SoloSeq model weights, e.g.:
```bash
bash scripts/download_openfold_soloseq_params.sh openfold/resources
```
Now, you are ready to run inference:
```bash
python run_pretrained_openfold.py \
...@@ -295,7 +310,7 @@ python run_pretrained_openfold.py \
    --output_dir ./ \
    --model_device "cuda:0" \
    --config_preset "seq_model_esm1b_ptm" \
    --openfold_checkpoint_path openfold/resources/openfold_soloseq_params/seq_model_esm1b_ptm.pt
```
For generating the embeddings during inference, skip the `--use_precomputed_alignments` argument. The `*.hhr` files will be generated as well if you pass the paths to the relevant databases and tools, as specified in the command below. If you skip the database and tool arguments, HHSearch will not be used to find templates and only generated ESM-1b embeddings will be used to predict the structure.
......
...@@ -12,7 +12,9 @@
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu"
},
"contiguous_gradients": true
},
"activation_checkpointing": {
...@@ -20,5 +22,6 @@
"cpu_checkpointing": false,
"profile": false
},
"gradient_clipping": 0.1,
"zero_force_ds_cpu_optimizer": false
}
...@@ -31,7 +31,7 @@ dependencies:
  - bioconda::kalign2==2.04
  - pytorch::pytorch=1.12.*
  - pip:
      - deepspeed==0.12.4
      - dm-tree==0.1.6
      - git+https://github.com/NVIDIA/dllogger.git
      - git+https://github.com/Dao-AILab/flash-attention.git@5b838a8
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
...@@ -57,10 +50,12 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "rowN0bVYLe9n"
},
"outputs": [],
"source": [
"#@markdown ### Enter the amino acid sequence to fold ⬇️\n",
"sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n",
...@@ -78,16 +73,16 @@
"\n",
"#@markdown After making your selections, execute this cell by pressing the\n",
"#@markdown *Play* button on the left."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "woIxeCPygt7K"
},
"outputs": [],
"source": [
"#@title Install third-party software\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n",
...@@ -97,75 +92,54 @@
"#@markdown **Note**: This installs the software on the Colab \n",
"#@markdown notebook in the cloud and not on your computer.\n",
"\n",
"import os, time\n",
"from IPython.utils import io\n",
"from sys import version_info\n",
"import subprocess\n",
"\n",
"python_version = f\"{version_info.major}.{version_info.minor}\"\n",
"\n",
"\n",
"os.system(\"wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh\")\n",
"os.system(\"bash Mambaforge-Linux-x86_64.sh -bfp /usr/local\")\n",
"os.system(\"mamba config --set auto_update_conda false\")\n",
"os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=7.7.0 python={python_version} pdbfixer\")\n",
"\n",
"\n",
"os.system(\"pip install -q \\\"torch<2\\\" biopython ml_collections py3Dmol modelcif\")\n",
"\n",
"try:\n",
" with io.capture_output() as captured:\n",
"\n",
" # Create a ramdisk to store a database chunk to make Jackhmmer run fast.\n",
" %shell sudo apt install --quiet --yes hmmer\n",
" %shell sudo mkdir -m 777 --parents /tmp/ramdisk\n",
" %shell sudo mount -t tmpfs -o size=9G ramdisk /tmp/ramdisk\n",
"\n",
" %shell wget -q -P /content \\\n",
" https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt\n",
"\n",
" %shell mkdir -p /content/openfold/openfold/resources\n",
" \n",
" commit = \"099769d2ecfd01a8baa8d950030df454a042c910\"\n",
" os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n",
" \n",
" %shell cp -f /content/stereo_chemical_props.txt /usr/local/lib/python3.10/site-packages/openfold/resources/\n",
"except subprocess.CalledProcessError as captured:\n",
" print(captured)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "VzJ5iMjTtoZw"
},
"outputs": [],
"source": [
"#@title Download model weights\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@markdown the left.\n",
"\n",
...@@ -180,13 +154,6 @@
"\n",
"try:\n",
" with io.capture_output() as captured:\n",
" if(weight_set == 'AlphaFold'):\n",
" %shell mkdir --parents \"{ALPHAFOLD_PARAMS_DIR}\"\n",
" %shell wget -O {ALPHAFOLD_PARAMS_PATH} {ALPHAFOLD_PARAM_SOURCE_URL}\n",
...@@ -194,7 +161,14 @@
" --directory=\"{ALPHAFOLD_PARAMS_DIR}\" --preserve-permissions\n",
" %shell rm \"{ALPHAFOLD_PARAMS_PATH}\"\n",
" elif(weight_set == 'OpenFold'):\n",
" # Install AWS CLI\n",
" %shell curl \"https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip\" -o \"awscliv2.zip\"\n",
" %shell unzip -qq awscliv2.zip\n",
" %shell sudo ./aws/install\n",
" %shell rm awscliv2.zip\n",
" %shell rm -rf ./aws\n",
" %shell mkdir --parents \"{OPENFOLD_PARAMS_DIR}\"\n",
"\n",
" %shell aws s3 cp \\\n",
" --no-sign-request \\\n",
" --region us-east-1 \\\n",
...@@ -203,14 +177,17 @@
" else:\n",
" raise ValueError(\"Invalid weight set\")\n",
"except subprocess.CalledProcessError as captured:\n",
" print(captured)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "_FpxxMo-mvcP"
},
"outputs": [],
"source": [
"#@title Import Python packages\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n",
...@@ -219,8 +196,8 @@
"import unittest.mock\n",
"import sys\n",
"\n",
"sys.path.insert(0, f'/usr/local/lib/python{python_version}/site-packages/')\n",
"sys.path.append(f'/opt/conda/lib/python{python_version}/site-packages')\n",
"\n",
"# Allows us to skip installing these packages\n",
"unnecessary_modules = [\n",
...@@ -245,6 +222,10 @@
"import py3Dmol\n",
"import torch\n",
"import shutil\n",
"import tqdm\n",
"import tqdm.notebook\n",
"\n",
"TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'\n",
"\n",
"# Prevent shell magic being broken by openmm, prevent this cryptic error:\n",
"# \"NotImplementedError: A UTF-8 locale is required. Got ANSI_X3.4-1968\"\n",
...@@ -280,13 +261,7 @@
"from IPython import display\n",
"from ipywidgets import GridspecLayout\n",
"from ipywidgets import Output"
]
},
{
"cell_type": "markdown",
...@@ -301,10 +276,12 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "2tTeTTsLKPjB"
},
"outputs": [],
"source": [
"#@title Search against genetic databases\n",
"\n",
...@@ -420,16 +397,16 @@
"plt.ylabel('Non-Gap Count')\n",
"plt.yticks(range(0, num_alignments + 1, max(1, int(num_alignments / 3))))\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "XUo6foMQxwS2"
},
"outputs": [],
"source": [
"#@title Run OpenFold and download prediction\n",
"\n",
...@@ -693,9 +670,7 @@
"# --- Download the predictions ---\n",
"shutil.make_archive(base_name='prediction', format='zip', root_dir=output_dir)\n",
"files.download(f'{output_dir}.zip')"
]
},
{
"cell_type": "markdown",
...@@ -789,5 +764,24 @@
"* BFD: (modified), by Steinegger M. and Söding J., modified by DeepMind, available under a [Creative Commons Attribution-ShareAlike 4.0 International License](https://creativecommons.org/licenses/by/4.0/). See the Methods section of the [AlphaFold proteome paper](https://www.nature.com/articles/s41586-021-03828-1) for details."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "OpenFold.ipynb",
"provenance": [],
"gpuType": "T4",
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
...@@ -29,19 +29,28 @@ def enforce_config_constraints(config):
        (
            "globals.use_lma",
            "globals.use_flash",
            "globals.use_deepspeed_evo_attention"
        ),
    ]

    for options in mutually_exclusive_bools:
        option_settings = [string_to_setting(o) for o in options]
        if sum(option_settings) > 1:
            raise ValueError(f"Only one of {', '.join(options)} may be set at a time")

    fa_is_installed = importlib.util.find_spec("flash_attn") is not None
    if config.globals.use_flash and not fa_is_installed:
        raise ValueError("use_flash requires that FlashAttention is installed")

    deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
    ds4s_is_installed = deepspeed_is_installed and importlib.util.find_spec(
        "deepspeed.ops.deepspeed4science") is not None
    if config.globals.use_deepspeed_evo_attention and not ds4s_is_installed:
        raise ValueError(
            "use_deepspeed_evo_attention requires that DeepSpeed be installed "
            "and that the deepspeed.ops.deepspeed4science package exists"
        )
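The availability checks above rely on `importlib.util.find_spec`, which detects whether a module can be imported without actually importing it. The pattern in isolation (a standalone sketch, not OpenFold code):

```python
import importlib.util

def is_installed(module_name: str) -> bool:
    # find_spec returns None when the top-level module cannot be located
    return importlib.util.find_spec(module_name) is not None

# stdlib modules are always present; a made-up name is not
print(is_installed("json"))                        # True
print(is_installed("module_that_does_not_exist"))  # False
```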
    if(
        config.globals.offload_inference and
        not config.model.template.average_templates
...@@ -158,10 +167,12 @@ def model_config(
    if name == "seqemb_initial_training":
        c.data.train.max_msa_clusters = 1
        c.data.eval.max_msa_clusters = 1
        c.data.train.block_delete_msa = False
        c.data.train.max_distillation_msa_clusters = 1
    elif name == "seqemb_finetuning":
        c.data.train.max_msa_clusters = 1
        c.data.eval.max_msa_clusters = 1
        c.data.train.block_delete_msa = False
        c.data.train.max_distillation_msa_clusters = 1
        c.data.train.crop_size = 384
        c.loss.violation.weight = 1.
...@@ -218,7 +229,8 @@ def model_config(
    if long_sequence_inference:
        assert(not train)
        c.globals.offload_inference = True
        # Default to DeepSpeed memory-efficient attention kernel unless use_lma is explicitly set
        c.globals.use_deepspeed_evo_attention = True if not c.globals.use_lma else False
        c.globals.use_flash = False
        c.model.template.offload_inference = True
        c.model.template.template_pair_stack.tune_chunk_size = False
...@@ -338,6 +350,11 @@ config = mlc.ConfigDict(
            "true_msa": [NUM_MSA_SEQ, NUM_RES],
            "use_clamped_fape": [],
        },
        "block_delete_msa": {
            "msa_fraction_per_block": 0.3,
            "randomize_num_blocks": False,
            "num_blocks": 5,
        },
        "masked_msa": {
            "profile_prob": 0.1,
            "same_prob": 0.1,
...@@ -382,6 +399,7 @@ config = mlc.ConfigDict(
        "predict": {
            "fixed_size": True,
            "subsample_templates": False,  # We want top templates.
            "block_delete_msa": False,
            "masked_msa_replace_fraction": 0.15,
            "max_msa_clusters": 512,
            "max_extra_msa": 1024,
...@@ -397,6 +415,7 @@ config = mlc.ConfigDict(
        "eval": {
            "fixed_size": True,
            "subsample_templates": False,  # We want top templates.
            "block_delete_msa": False,
            "masked_msa_replace_fraction": 0.15,
            "max_msa_clusters": 128,
            "max_extra_msa": 1024,
...@@ -412,6 +431,7 @@ config = mlc.ConfigDict(
        "train": {
            "fixed_size": True,
            "subsample_templates": True,
            "block_delete_msa": True,
            "masked_msa_replace_fraction": 0.15,
            "max_msa_clusters": 128,
            "max_extra_msa": 1024,
...@@ -441,11 +461,15 @@ config = mlc.ConfigDict(
        "globals": {
            "blocks_per_ckpt": blocks_per_ckpt,
            "chunk_size": chunk_size,
            # Use DeepSpeed memory-efficient attention kernel. Mutually
            # exclusive with use_lma and use_flash.
            "use_deepspeed_evo_attention": False,
            # Use Staats & Rabe's low-memory attention algorithm. Mutually
            # exclusive with use_deepspeed_evo_attention and use_flash.
            "use_lma": False,
            # Use FlashAttention in selected modules. Mutually exclusive with
            # use_deepspeed_evo_attention and use_lma. Doesn't work that well
            # on long sequences (>1000 residues).
            "use_flash": False,
            "offload_inference": False,
            "c_z": c_z,
...@@ -799,6 +823,7 @@ multimer_config_update = mlc.ConfigDict({
        "train": {
            "max_msa_clusters": 508,
            "max_extra_msa": 2048,
            "block_delete_msa": False,
            "crop_size": 640,
            "spatial_crop_prob": 0.5,
            "interface_threshold": 10.,
......
...@@ -255,28 +255,33 @@ def block_delete_msa(protein, config):
        * config.msa_fraction_per_block
    ).to(torch.int32)

    if int(block_num_seq) == 0:
        return protein

    if config.randomize_num_blocks:
        nb = int(torch.randint(
            low=0,
            high=config.num_blocks + 1,
            size=(1,),
            device=protein["msa"].device,
        )[0])
    else:
        nb = config.num_blocks

    del_block_starts = torch.randint(low=1, high=num_seq, size=(nb,), device=protein["msa"].device)
    del_blocks = del_block_starts[:, None] + torch.arange(start=0, end=block_num_seq)
    del_blocks = torch.clip(del_blocks, 1, num_seq - 1)
    del_indices = torch.unique(torch.reshape(del_blocks, [-1]))

    # Make sure we keep the original sequence
    combined = torch.cat((torch.arange(start=0, end=num_seq), del_indices)).long()
    uniques, counts = combined.unique(return_counts=True)
    keep_indices = uniques[counts == 1]

    for k in MSA_FEATURE_NAMES:
        if k in protein:
            protein[k] = torch.index_select(protein[k], 0, keep_indices)

    return protein
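The index bookkeeping in the rewritten function can be sketched in plain Python (a simplified stand-in for the torch version; the deletion indices are always ≥ 1, so row 0, the query sequence, survives):

```python
def keep_indices_after_block_delete(num_seq, del_indices):
    """Rows that survive block deletion.

    Mirrors the torch logic: concatenate range(num_seq) with the
    deletion indices, then keep values that appear exactly once.
    """
    counts = {}
    for i in list(range(num_seq)) + list(del_indices):
        counts[i] = counts.get(i, 0) + 1
    return [i for i in range(num_seq) if counts[i] == 1]

# rows 2 and 3 scheduled for deletion; the query row 0 is always kept
print(keep_indices_after_block_delete(5, [2, 3]))  # [0, 1, 4]
```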
......
...@@ -71,6 +71,9 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
    """Input pipeline data transformers that can be ensembled and averaged."""
    transforms = []

    if mode_cfg.block_delete_msa:
        transforms.append(data_transforms.block_delete_msa(common_cfg.block_delete_msa))

    if "max_distillation_msa_clusters" in mode_cfg:
        transforms.append(
            data_transforms.sample_msa_distillation(
......
...@@ -657,6 +657,7 @@ class TemplateEmbedder(nn.Module):
        templ_dim,
        chunk_size,
        _mask_trans=True,
        use_deepspeed_evo_attention=False,
        use_lma=False,
        inplace_safe=False
    ):
...@@ -707,6 +708,7 @@ class TemplateEmbedder(nn.Module):
            t_pair,
            pair_mask.unsqueeze(-3).to(dtype=z.dtype),
            chunk_size=chunk_size,
            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
            use_lma=use_lma,
            inplace_safe=inplace_safe,
            _mask_trans=_mask_trans,
...@@ -893,6 +895,7 @@ class TemplateEmbedderMultimer(nn.Module):
        chunk_size,
        multichain_mask_2d,
        _mask_trans=True,
        use_deepspeed_evo_attention=False,
        use_lma=False,
        inplace_safe=False
    ):
...@@ -967,6 +970,7 @@ class TemplateEmbedderMultimer(nn.Module):
            template_embeds["template_pair_embedding"],
            padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
            chunk_size=chunk_size,
            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
            use_lma=use_lma,
            inplace_safe=inplace_safe,
            _mask_trans=_mask_trans,
......
...@@ -90,7 +90,6 @@ class MSATransition(nn.Module):
            no_batch_dims=len(m.shape[:-2]),
        )

    def forward(
        self,
        m: torch.Tensor,
...@@ -179,6 +178,7 @@ class PairStack(nn.Module):
        z: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        use_deepspeed_evo_attention: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
        _mask_trans: bool = True,
...@@ -225,6 +225,7 @@ class PairStack(nn.Module):
                mask=pair_mask,
                chunk_size=_attn_chunk_size,
                use_memory_efficient_kernel=False,
                use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                use_lma=use_lma,
                inplace_safe=inplace_safe,
            )
...@@ -243,6 +244,7 @@ class PairStack(nn.Module): ...@@ -243,6 +244,7 @@ class PairStack(nn.Module):
mask=pair_mask.transpose(-1, -2), mask=pair_mask.transpose(-1, -2),
chunk_size=_attn_chunk_size, chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False, use_memory_efficient_kernel=False,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
) )
...@@ -360,6 +362,7 @@ class MSABlock(nn.Module, ABC): ...@@ -360,6 +362,7 @@ class MSABlock(nn.Module, ABC):
msa_mask: torch.Tensor, msa_mask: torch.Tensor,
pair_mask: torch.Tensor, pair_mask: torch.Tensor,
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
use_flash: bool = False, use_flash: bool = False,
inplace_safe: bool = False, inplace_safe: bool = False,
...@@ -423,6 +426,7 @@ class EvoformerBlock(MSABlock): ...@@ -423,6 +426,7 @@ class EvoformerBlock(MSABlock):
msa_mask: torch.Tensor, msa_mask: torch.Tensor,
pair_mask: torch.Tensor, pair_mask: torch.Tensor,
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
use_flash: bool = False, use_flash: bool = False,
inplace_safe: bool = False, inplace_safe: bool = False,
...@@ -462,6 +466,7 @@ class EvoformerBlock(MSABlock): ...@@ -462,6 +466,7 @@ class EvoformerBlock(MSABlock):
mask=msa_mask, mask=msa_mask,
chunk_size=_attn_chunk_size, chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False, use_memory_efficient_kernel=False,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
) )
), ),
...@@ -483,6 +488,7 @@ class EvoformerBlock(MSABlock): ...@@ -483,6 +488,7 @@ class EvoformerBlock(MSABlock):
m, m,
mask=msa_mask, mask=msa_mask,
chunk_size=chunk_size, chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
use_flash=use_flash, use_flash=use_flash,
), ),
...@@ -527,6 +533,7 @@ class EvoformerBlock(MSABlock): ...@@ -527,6 +533,7 @@ class EvoformerBlock(MSABlock):
z=input_tensors[1], z=input_tensors[1],
pair_mask=pair_mask, pair_mask=pair_mask,
chunk_size=chunk_size, chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
_mask_trans=_mask_trans, _mask_trans=_mask_trans,
...@@ -602,6 +609,7 @@ class ExtraMSABlock(MSABlock): ...@@ -602,6 +609,7 @@ class ExtraMSABlock(MSABlock):
msa_mask: torch.Tensor, msa_mask: torch.Tensor,
pair_mask: torch.Tensor, pair_mask: torch.Tensor,
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
inplace_safe: bool = False, inplace_safe: bool = False,
_mask_trans: bool = True, _mask_trans: bool = True,
...@@ -637,7 +645,8 @@ class ExtraMSABlock(MSABlock): ...@@ -637,7 +645,8 @@ class ExtraMSABlock(MSABlock):
mask=msa_mask, mask=msa_mask,
chunk_size=_attn_chunk_size, chunk_size=_attn_chunk_size,
use_lma=use_lma, use_lma=use_lma,
use_memory_efficient_kernel=not use_lma and m.is_cuda, use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_memory_efficient_kernel=not (use_lma or use_deepspeed_evo_attention),
_checkpoint_chunks= _checkpoint_chunks=
self.ckpt if torch.is_grad_enabled() else False, self.ckpt if torch.is_grad_enabled() else False,
) )
...@@ -709,6 +718,7 @@ class ExtraMSABlock(MSABlock): ...@@ -709,6 +718,7 @@ class ExtraMSABlock(MSABlock):
input_tensors[1], input_tensors[1],
pair_mask=pair_mask, pair_mask=pair_mask,
chunk_size=chunk_size, chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
_mask_trans=_mask_trans, _mask_trans=_mask_trans,
...@@ -853,6 +863,7 @@ class EvoformerStack(nn.Module): ...@@ -853,6 +863,7 @@ class EvoformerStack(nn.Module):
m: torch.Tensor, m: torch.Tensor,
z: torch.Tensor, z: torch.Tensor,
chunk_size: int, chunk_size: int,
use_deepspeed_evo_attention: bool,
use_lma: bool, use_lma: bool,
use_flash: bool, use_flash: bool,
msa_mask: Optional[torch.Tensor], msa_mask: Optional[torch.Tensor],
...@@ -866,6 +877,7 @@ class EvoformerStack(nn.Module): ...@@ -866,6 +877,7 @@ class EvoformerStack(nn.Module):
msa_mask=msa_mask, msa_mask=msa_mask,
pair_mask=pair_mask, pair_mask=pair_mask,
chunk_size=chunk_size, chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
use_flash=use_flash, use_flash=use_flash,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
...@@ -905,6 +917,7 @@ class EvoformerStack(nn.Module): ...@@ -905,6 +917,7 @@ class EvoformerStack(nn.Module):
msa_mask: torch.Tensor, msa_mask: torch.Tensor,
pair_mask: torch.Tensor, pair_mask: torch.Tensor,
chunk_size: int, chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
use_flash: bool = False, use_flash: bool = False,
_mask_trans: bool = True, _mask_trans: bool = True,
...@@ -916,6 +929,7 @@ class EvoformerStack(nn.Module): ...@@ -916,6 +929,7 @@ class EvoformerStack(nn.Module):
m=input_tensors[0], m=input_tensors[0],
z=input_tensors[1], z=input_tensors[1],
chunk_size=chunk_size, chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
use_flash=use_flash, use_flash=use_flash,
msa_mask=msa_mask, msa_mask=msa_mask,
...@@ -947,6 +961,7 @@ class EvoformerStack(nn.Module): ...@@ -947,6 +961,7 @@ class EvoformerStack(nn.Module):
msa_mask: torch.Tensor, msa_mask: torch.Tensor,
pair_mask: torch.Tensor, pair_mask: torch.Tensor,
chunk_size: int, chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
use_flash: bool = False, use_flash: bool = False,
inplace_safe: bool = False, inplace_safe: bool = False,
...@@ -965,10 +980,15 @@ class EvoformerStack(nn.Module): ...@@ -965,10 +980,15 @@ class EvoformerStack(nn.Module):
chunk_size: chunk_size:
Inference-time subbatch size. Acts as a minimum if Inference-time subbatch size. Acts as a minimum if
self.tune_chunk_size is True self.tune_chunk_size is True
use_lma: Whether to use low-memory attention during inference use_deepspeed_evo_attention:
Whether to use DeepSpeed memory efficient kernel.
Mutually exclusive with use_lma and use_flash.
use_lma:
Whether to use low-memory attention during inference.
Mutually exclusive with use_flash and use_deepspeed_evo_attention.
use_flash: use_flash:
Whether to use FlashAttention where possible. Mutually Whether to use FlashAttention where possible. Mutually
exclusive with use_lma. exclusive with use_lma and use_deepspeed_evo_attention.
Returns: Returns:
m: m:
[*, N_seq, N_res, C_m] MSA embedding [*, N_seq, N_res, C_m] MSA embedding
...@@ -981,6 +1001,7 @@ class EvoformerStack(nn.Module): ...@@ -981,6 +1001,7 @@ class EvoformerStack(nn.Module):
m=m, m=m,
z=z, z=z,
chunk_size=chunk_size, chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
use_flash=use_flash, use_flash=use_flash,
msa_mask=msa_mask, msa_mask=msa_mask,
...@@ -1065,6 +1086,7 @@ class ExtraMSAStack(nn.Module): ...@@ -1065,6 +1086,7 @@ class ExtraMSAStack(nn.Module):
m: torch.Tensor, m: torch.Tensor,
z: torch.Tensor, z: torch.Tensor,
chunk_size: int, chunk_size: int,
use_deepspeed_evo_attention: bool,
use_lma: bool, use_lma: bool,
msa_mask: Optional[torch.Tensor], msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor], pair_mask: Optional[torch.Tensor],
...@@ -1076,7 +1098,8 @@ class ExtraMSAStack(nn.Module): ...@@ -1076,7 +1098,8 @@ class ExtraMSAStack(nn.Module):
b, b,
msa_mask=msa_mask, msa_mask=msa_mask,
pair_mask=pair_mask, pair_mask=pair_mask,
chunk_size=chunk_size, chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
_mask_trans=_mask_trans, _mask_trans=_mask_trans,
...@@ -1113,6 +1136,7 @@ class ExtraMSAStack(nn.Module): ...@@ -1113,6 +1136,7 @@ class ExtraMSAStack(nn.Module):
def _forward_offload(self, def _forward_offload(self,
input_tensors: Sequence[torch.Tensor], input_tensors: Sequence[torch.Tensor],
chunk_size: int, chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
msa_mask: Optional[torch.Tensor] = None, msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None, pair_mask: Optional[torch.Tensor] = None,
...@@ -1125,6 +1149,7 @@ class ExtraMSAStack(nn.Module): ...@@ -1125,6 +1149,7 @@ class ExtraMSAStack(nn.Module):
m=input_tensors[0], m=input_tensors[0],
z=input_tensors[1], z=input_tensors[1],
chunk_size=chunk_size, chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
msa_mask=msa_mask, msa_mask=msa_mask,
pair_mask=pair_mask, pair_mask=pair_mask,
...@@ -1151,6 +1176,7 @@ class ExtraMSAStack(nn.Module): ...@@ -1151,6 +1176,7 @@ class ExtraMSAStack(nn.Module):
msa_mask: Optional[torch.Tensor], msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor], pair_mask: Optional[torch.Tensor],
chunk_size: int, chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
inplace_safe: bool = False, inplace_safe: bool = False,
_mask_trans: bool = True, _mask_trans: bool = True,
...@@ -1162,6 +1188,7 @@ class ExtraMSAStack(nn.Module): ...@@ -1162,6 +1188,7 @@ class ExtraMSAStack(nn.Module):
z: z:
[*, N_res, N_res, C_z] pair embedding [*, N_res, N_res, C_z] pair embedding
chunk_size: Inference-time subbatch size for Evoformer modules chunk_size: Inference-time subbatch size for Evoformer modules
use_deepspeed_evo_attention: Whether to use DeepSpeed memory-efficient kernel
use_lma: Whether to use low-memory attention during inference use_lma: Whether to use low-memory attention during inference
msa_mask: msa_mask:
Optional [*, N_extra, N_res] MSA mask Optional [*, N_extra, N_res] MSA mask
...@@ -1175,6 +1202,7 @@ class ExtraMSAStack(nn.Module): ...@@ -1175,6 +1202,7 @@ class ExtraMSAStack(nn.Module):
m=m, m=m,
z=z, z=z,
chunk_size=chunk_size, chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
msa_mask=msa_mask, msa_mask=msa_mask,
pair_mask=pair_mask, pair_mask=pair_mask,
......
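The updated docstrings in this file state that `use_deepspeed_evo_attention`, `use_lma`, and `use_flash` are mutually exclusive, but the diff itself does not show an enforcement point. A small caller-side guard (a sketch, not part of OpenFold) makes the constraint explicit:

```python
def check_attention_flags(use_deepspeed_evo_attention=False, use_lma=False,
                          use_flash=False):
    # The docstrings declare the three optional attention backends mutually
    # exclusive; failing fast beats silently picking one of two backends.
    if sum((use_deepspeed_evo_attention, use_lma, use_flash)) > 1:
        raise ValueError(
            "use_deepspeed_evo_attention, use_lma and use_flash are "
            "mutually exclusive"
        )
```

Any single flag (or none) passes; requesting two or more raises immediately.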
@@ -103,8 +103,8 @@ class AlphaFold(nn.Module):
**self.config["recycling_embedder"],
)
- if (self.template_config.enabled):
- if(self.globals.is_multimer):
+ if self.template_config.enabled:
+ if self.globals.is_multimer:
self.template_embedder = TemplateEmbedderMultimer(
self.template_config,
)
@@ -113,14 +113,14 @@ class AlphaFold(nn.Module):
self.template_config,
)
- if (self.extra_msa_config.enabled):
+ if self.extra_msa_config.enabled:
self.extra_msa_embedder = ExtraMSAEmbedder(
**self.extra_msa_config["extra_msa_embedder"],
)
self.extra_msa_stack = ExtraMSAStack(
**self.extra_msa_config["extra_msa_stack"],
)
self.evoformer = EvoformerStack(
**self.config["evoformer_stack"],
)
@@ -134,10 +134,10 @@ class AlphaFold(nn.Module):
)
def embed_templates(self, batch, feats, z, pair_mask, templ_dim, inplace_safe):
- if (self.globals.is_multimer):
+ if self.globals.is_multimer:
asym_id = feats["asym_id"]
multichain_mask_2d = (
asym_id[..., None] == asym_id[..., None, :]
)
template_embeds = self.template_embedder(
batch,
@@ -146,19 +146,20 @@ class AlphaFold(nn.Module):
templ_dim,
chunk_size=self.globals.chunk_size,
multichain_mask_2d=multichain_mask_2d,
+ use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe,
- _mask_trans = self.config._mask_trans
+ _mask_trans=self.config._mask_trans
)
feats["template_torsion_angles_mask"] = (
template_embeds["template_mask"]
)
else:
- if (self.template_config.offload_templates):
+ if self.template_config.offload_templates:
return embed_templates_offload(self,
batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
)
- elif (self.template_config.average_templates):
+ elif self.template_config.average_templates:
return embed_templates_average(self,
batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
)
@@ -169,6 +170,7 @@ class AlphaFold(nn.Module):
pair_mask.to(dtype=z.dtype),
templ_dim,
chunk_size=self.globals.chunk_size,
+ use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans
@@ -176,7 +178,7 @@ class AlphaFold(nn.Module):
return template_embeds
- def tolerance_reached(self, prev_pos, next_pos, mask, no_batch_dims, eps=1e-8) -> bool:
+ def tolerance_reached(self, prev_pos, next_pos, mask, eps=1e-8) -> bool:
"""
Early stopping criteria based on criteria used in
AF2Complex: https://www.nature.com/articles/s41467-022-29394-2
@@ -188,6 +190,7 @@ class AlphaFold(nn.Module):
Returns:
Whether to stop recycling early based on the desired tolerance.
"""
def distances(points):
"""Compute all pairwise distances for a set of points."""
d = points[..., None, :] - points[..., None, :, :]
@@ -210,7 +213,7 @@ class AlphaFold(nn.Module):
# This needs to be done manually for DeepSpeed's sake
dtype = next(self.parameters()).dtype
for k in feats:
- if(feats[k].dtype == torch.float32):
+ if feats[k].dtype == torch.float32:
feats[k] = feats[k].to(dtype=dtype)
# Grab some data about the input
@@ -219,7 +222,7 @@ class AlphaFold(nn.Module):
n = feats["target_feat"].shape[-2]
n_seq = feats["msa_feat"].shape[-3]
device = feats["target_feat"].device
# Controls whether the model uses in-place operations throughout
# The dual condition accounts for activation checkpoints
inplace_safe = not (self.training or torch.is_grad_enabled())
@@ -283,7 +286,7 @@ class AlphaFold(nn.Module):
).to(dtype=z.dtype)
# The recycling embedder is memory-intensive, so we offload first
- if(self.globals.offload_inference and inplace_safe):
+ if self.globals.offload_inference and inplace_safe:
m = m.cpu()
z = z.cpu()
@@ -298,7 +301,7 @@ class AlphaFold(nn.Module):
del pseudo_beta_x_prev
- if(self.globals.offload_inference and inplace_safe):
+ if self.globals.offload_inference and inplace_safe:
m = m.to(m_1_prev_emb.device)
z = z.to(z_prev.device)
@@ -314,7 +317,7 @@ class AlphaFold(nn.Module):
del m_1_prev, z_prev, m_1_prev_emb, z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
if self.config.template.enabled:
template_feats = {
k: v for k, v in feats.items() if k.startswith("template_")
}
@@ -330,24 +333,24 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z]
z = add(z,
template_embeds.pop("template_pair_embedding"),
inplace_safe,
)
- if(
+ if (
"template_single_embedding" in template_embeds
):
# [*, S = S_c + S_t, N, C_m]
m = torch.cat(
[m, template_embeds["template_single_embedding"]],
dim=-3
)
# [*, S, N]
- if(not self.globals.is_multimer):
+ if not self.globals.is_multimer:
torsion_angles_mask = feats["template_torsion_angles_mask"]
msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]],
dim=-2
)
else:
@@ -358,31 +361,32 @@ class AlphaFold(nn.Module):
# Embed extra MSA features + merge with pairwise embeddings
if self.config.extra_msa.enabled:
- if(self.globals.is_multimer):
+ if self.globals.is_multimer:
extra_msa_fn = data_transforms_multimer.build_extra_msa_feat
else:
extra_msa_fn = build_extra_msa_feat
# [*, S_e, N, C_e]
extra_msa_feat = extra_msa_fn(feats).to(dtype=z.dtype)
a = self.extra_msa_embedder(extra_msa_feat)
- if(self.globals.offload_inference):
+ if self.globals.offload_inference:
# To allow the extra MSA stack (and later the evoformer) to
# offload its inputs, we remove all references to them here
input_tensors = [a, z]
del a, z
# [*, N, N, C_z]
z = self.extra_msa_stack._forward_offload(
input_tensors,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
+ use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
_mask_trans=self.config._mask_trans,
)
del input_tensors
else:
# [*, N, N, C_z]
@@ -390,6 +394,7 @@ class AlphaFold(nn.Module):
a, z,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
+ use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
inplace_safe=inplace_safe,
@@ -400,7 +405,7 @@ class AlphaFold(nn.Module):
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
- if(self.globals.offload_inference):
+ if self.globals.offload_inference:
input_tensors = [m, z]
del m, z
m, z, s = self.evoformer._forward_offload(
@@ -408,10 +413,11 @@ class AlphaFold(nn.Module):
msa_mask=msa_mask.to(dtype=input_tensors[0].dtype),
pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
chunk_size=self.globals.chunk_size,
+ use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
_mask_trans=self.config._mask_trans,
)
del input_tensors
else:
m, z, s = self.evoformer(
@@ -420,6 +426,7 @@ class AlphaFold(nn.Module):
msa_mask=msa_mask.to(dtype=m.dtype),
pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
+ use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
use_flash=self.globals.use_flash,
inplace_safe=inplace_safe,
@@ -456,7 +463,7 @@ class AlphaFold(nn.Module):
early_stop = False
if self.globals.is_multimer:
- early_stop = self.tolerance_reached(x_prev, outputs["final_atom_positions"], seq_mask, no_batch_dims)
+ early_stop = self.tolerance_reached(x_prev, outputs["final_atom_positions"], seq_mask)
del x_prev
@@ -544,7 +551,7 @@ class AlphaFold(nn.Module):
num_iters = batch["aatype"].shape[-1]
early_stop = False
num_recycles = 0
for cycle_no in range(num_iters):
# Select the features for the current recycling cycle
fetch_cur_batch = lambda t: t[..., cycle_no]
feats = tensor_tree_map(fetch_cur_batch, batch)
...
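`tolerance_reached` (simplified in this diff to drop the unused `no_batch_dims` argument) implements AF2Complex-style early stopping: recycling halts once the RMS change of the all-pairs distance matrix between iterations falls below a tolerance. A minimal plain-Python sketch of the idea, without the per-residue mask the real method applies:

```python
import math

def pairwise_distances(points):
    """All pairwise Euclidean distances for a list of 3-D points."""
    return [[math.dist(p, q) for q in points] for p in points]

def tolerance_reached(prev_pos, next_pos, tol=0.5, eps=1e-8):
    # Early stop when the RMS difference between the pairwise distance
    # matrices of consecutive recycling iterations drops below `tol`.
    # Simplified sketch: no mask, plain lists instead of torch tensors.
    dp = pairwise_distances(prev_pos)
    dn = pairwise_distances(next_pos)
    sq = [(a - b) ** 2 for rp, rn in zip(dp, dn) for a, b in zip(rp, rn)]
    return math.sqrt(sum(sq) / len(sq) + eps) < tol
```

Comparing distance matrices rather than raw coordinates makes the criterion invariant to global rotations and translations of the structure between iterations.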
@@ -91,7 +91,8 @@ class MSAAttention(nn.Module):
m: torch.Tensor,
biases: Optional[List[torch.Tensor]],
chunk_size: int,
use_memory_efficient_kernel: bool,
+ use_deepspeed_evo_attention: bool,
use_lma: bool,
use_flash: bool,
flash_mask: Optional[torch.Tensor],
@@ -103,6 +104,7 @@ class MSAAttention(nn.Module):
kv_x=m,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
+ use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=flash_mask,
@@ -221,6 +223,7 @@ class MSAAttention(nn.Module):
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
+ use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
@@ -267,7 +270,8 @@ class MSAAttention(nn.Module):
m,
biases,
chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel,
+ use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=mask,
@@ -279,6 +283,7 @@ class MSAAttention(nn.Module):
kv_x=m,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
+ use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=mask,
@@ -356,6 +361,7 @@ class MSAColumnAttention(nn.Module):
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
+ use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
use_flash: bool = False,
) -> torch.Tensor:
@@ -378,7 +384,8 @@ class MSAColumnAttention(nn.Module):
m = self._msa_att(
m,
mask=mask,
chunk_size=chunk_size,
+ use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
)
...
@@ -12,20 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- from functools import partial
import importlib
import math
- from typing import Optional, Callable, List, Tuple, Sequence
+ from typing import Optional, Callable, List, Tuple
import numpy as np
deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
+ ds4s_is_installed = deepspeed_is_installed and importlib.util.find_spec("deepspeed.ops.deepspeed4science") is not None
- if(deepspeed_is_installed):
+ if deepspeed_is_installed:
import deepspeed
+ if ds4s_is_installed:
+ from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
fa_is_installed = importlib.util.find_spec("flash_attn") is not None
- if(fa_is_installed):
+ if fa_is_installed:
- from flash_attn.bert_padding import unpad_input, pad_input
+ from flash_attn.bert_padding import unpad_input
- from flash_attn.flash_attention import FlashAttention
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
import torch
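The guarded imports above probe for optional dependencies with `importlib.util.find_spec` before importing them, and they check `deepspeed_is_installed` before probing for the `deepspeed.ops.deepspeed4science` submodule. That ordering matters: `find_spec("pkg.sub")` raises `ModuleNotFoundError` when `pkg` itself is absent. A small sketch of the pattern (the helper name is ours, not OpenFold's):

```python
import importlib.util
from typing import Optional

def optional_module_available(name: str, parent: Optional[str] = None) -> bool:
    # Probe the parent package first: importlib.util.find_spec("pkg.sub")
    # raises ModuleNotFoundError when "pkg" itself is not installed, which
    # is why the diff short-circuits on deepspeed_is_installed before
    # looking for deepspeed.ops.deepspeed4science.
    if parent is not None and importlib.util.find_spec(parent) is None:
        return False
    return importlib.util.find_spec(name) is not None
```

With this guard the module imports cleanly whether or not DeepSpeed, the DS4Sci kernels, or FlashAttention are installed; the corresponding `use_*` flags simply fail at call time if the backend is missing.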
@@ -33,7 +35,6 @@ import torch.nn as nn
from scipy.stats import truncnorm
from openfold.utils.checkpointing import get_checkpoint_fn
- from openfold.utils.chunk_utils import _chunk_slice
from openfold.utils.kernel.attention_core import attention_core
from openfold.utils.precision_utils import is_fp16_enabled
from openfold.utils.tensor_utils import (
@@ -42,8 +43,8 @@ from openfold.utils.tensor_utils import (
)
- DEFAULT_LMA_Q_CHUNK_SIZE=1024
- DEFAULT_LMA_KV_CHUNK_SIZE=4096
+ DEFAULT_LMA_Q_CHUNK_SIZE = 1024
+ DEFAULT_LMA_KV_CHUNK_SIZE = 4096
def _prod(nums):
@@ -217,9 +218,9 @@ class LayerNorm(nn.Module):
d = x.dtype
deepspeed_is_initialized = (
deepspeed_is_installed and
- deepspeed.utils.is_initialized()
+ deepspeed.comm.comm.is_initialized()
)
- if(d is torch.bfloat16 and not deepspeed_is_initialized):
+ if d is torch.bfloat16 and not deepspeed_is_initialized:
with torch.cuda.amp.autocast(enabled=False):
out = nn.functional.layer_norm(
x,
@@ -249,9 +250,9 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
d = t.dtype
deepspeed_is_initialized = (
deepspeed_is_installed and
- deepspeed.utils.is_initialized()
+ deepspeed.comm.comm.is_initialized()
)
- if(d is torch.bfloat16 and not deepspeed_is_initialized):
+ if d is torch.bfloat16 and not deepspeed_is_initialized:
with torch.cuda.amp.autocast(enabled=False):
s = torch.nn.functional.softmax(t, dim=dim)
else:
@@ -283,7 +284,7 @@ def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, bias
def _attention_chunked_trainable(
query, key, value, biases, chunk_size, chunk_dim, checkpoint,
):
- if(checkpoint and len(biases) > 2):
+ if checkpoint and len(biases) > 2:
raise ValueError(
"Checkpointed version permits only permits two bias terms"
)
@@ -311,7 +312,7 @@ def _attention_chunked_trainable(
)
return b[tuple(idx)]
- if(checkpoint):
+ if checkpoint:
bias_1_chunk, bias_2_chunk = [
_slice_bias(b) if b is not None else None
for b in (biases + [None, None])[:2]
@@ -398,7 +399,8 @@ class Attention(nn.Module):
def _prep_qkv(self,
q_x: torch.Tensor,
- kv_x: torch.Tensor
+ kv_x: torch.Tensor,
+ apply_scale: bool = True
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor
]:
@@ -417,7 +419,8 @@ class Attention(nn.Module):
k = k.transpose(-2, -3)
v = v.transpose(-2, -3)
- q /= math.sqrt(self.c_hidden)
+ if apply_scale:
+ q /= math.sqrt(self.c_hidden)
return q, k, v
...@@ -425,7 +428,7 @@ class Attention(nn.Module): ...@@ -425,7 +428,7 @@ class Attention(nn.Module):
o: torch.Tensor, o: torch.Tensor,
q_x: torch.Tensor q_x: torch.Tensor
) -> torch.Tensor: ) -> torch.Tensor:
if(self.linear_g is not None): if self.linear_g is not None:
g = self.sigmoid(self.linear_g(q_x)) g = self.sigmoid(self.linear_g(q_x))
# [*, Q, H, C_hidden] # [*, Q, H, C_hidden]
...@@ -446,11 +449,12 @@ class Attention(nn.Module): ...@@ -446,11 +449,12 @@ class Attention(nn.Module):
kv_x: torch.Tensor, kv_x: torch.Tensor,
biases: Optional[List[torch.Tensor]] = None, biases: Optional[List[torch.Tensor]] = None,
use_memory_efficient_kernel: bool = False, use_memory_efficient_kernel: bool = False,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE, lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE, lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
use_flash: bool = False, use_flash: bool = False,
flash_mask: Optional[torch.Tensor] = None, flash_mask: Optional[torch.Tensor] = None
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
...@@ -465,6 +469,10 @@ class Attention(nn.Module): ...@@ -465,6 +469,10 @@ class Attention(nn.Module):
This should be the default choice for most. If none of the This should be the default choice for most. If none of the
"use_<...>" flags are True, a stock PyTorch implementation "use_<...>" flags are True, a stock PyTorch implementation
is used instead is used instead
use_deepspeed_evo_attention:
Whether to use DeepSpeed memory-efficient attention kernel.
If none of the "use_<...>" flags are True, a stock PyTorch
implementation is used instead
use_lma: use_lma:
Whether to use low-memory attention (Staats & Rabe 2021). If Whether to use low-memory attention (Staats & Rabe 2021). If
none of the "use_<...>" flags are True, a stock PyTorch none of the "use_<...>" flags are True, a stock PyTorch
...@@ -476,50 +484,57 @@ class Attention(nn.Module): ...@@ -476,50 +484,57 @@ class Attention(nn.Module):
Returns Returns
[*, Q, C_q] attention update [*, Q, C_q] attention update
""" """
if(use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None)): if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
raise ValueError( raise ValueError(
"If use_lma is specified, lma_q_chunk_size and " "If use_lma is specified, lma_q_chunk_size and "
"lma_kv_chunk_size must be provided" "lma_kv_chunk_size must be provided"
) )
if(use_flash and biases is not None): if use_flash and biases is not None:
raise ValueError( raise ValueError(
"use_flash is incompatible with the bias option. For masking, " "use_flash is incompatible with the bias option. For masking, "
"use flash_mask instead" "use flash_mask instead"
) )
attn_options = [use_memory_efficient_kernel, use_lma, use_flash] attn_options = [use_memory_efficient_kernel, use_deepspeed_evo_attention, use_lma, use_flash]
if(sum(attn_options) > 1): if sum(attn_options) > 1:
raise ValueError( raise ValueError(
"Choose at most one alternative attention algorithm" "Choose at most one alternative attention algorithm"
) )
if(biases is None): if biases is None:
biases = [] biases = []
# [*, H, Q/K, C_hidden] # DeepSpeed attention kernel applies scaling internally
q, k, v = self._prep_qkv(q_x, kv_x) q, k, v = self._prep_qkv(q_x, kv_x,
apply_scale=not use_deepspeed_evo_attention)
# [*, Q, H, C_hidden]
if is_fp16_enabled(): if is_fp16_enabled():
use_memory_efficient_kernel = False use_memory_efficient_kernel = False
if(use_memory_efficient_kernel): if use_memory_efficient_kernel:
if(len(biases) > 2): if len(biases) > 2:
raise ValueError( raise ValueError(
"If use_memory_efficient_kernel is True, you may only " "If use_memory_efficient_kernel is True, you may only "
"provide up to two bias terms" "provide up to two bias terms"
) )
o = attention_core(q, k, v, *((biases + [None] * 2)[:2])) o = attention_core(q, k, v, *((biases + [None] * 2)[:2]))
o = o.transpose(-2, -3) o = o.transpose(-2, -3)
elif(use_lma): elif use_deepspeed_evo_attention:
if len(biases) > 2:
raise ValueError(
"If use_deepspeed_evo_attention is True, you may only "
"provide up to two bias terms"
)
o = _deepspeed_evo_attn(q, k, v, biases)
elif use_lma:
biases = [ biases = [
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],)) b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
for b in biases for b in biases
] ]
o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size) o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size)
o = o.transpose(-2, -3) o = o.transpose(-2, -3)
elif(use_flash): elif use_flash:
o = _flash_attn(q, k, v, flash_mask) o = _flash_attn(q, k, v, flash_mask)
else: else:
o = _attention(q, k, v, biases) o = _attention(q, k, v, biases)
...@@ -577,7 +592,7 @@ class GlobalAttention(nn.Module): ...@@ -577,7 +592,7 @@ class GlobalAttention(nn.Module):
v = self.linear_v(m) v = self.linear_v(m)
bias = (self.inf * (mask - 1))[..., :, None, :] bias = (self.inf * (mask - 1))[..., :, None, :]
if(not use_lma): if not use_lma:
# [*, N_res, H, N_seq] # [*, N_res, H, N_seq]
a = torch.matmul( a = torch.matmul(
q, q,
...@@ -619,6 +634,72 @@ class GlobalAttention(nn.Module): ...@@ -619,6 +634,72 @@ class GlobalAttention(nn.Module):
return m return m
@torch.jit.ignore
def _deepspeed_evo_attn(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
biases: List[torch.Tensor],
):
"""
Compute attention using the DeepSpeed DS4Sci_EvoformerAttention kernel.
Args:
q:
[*, H, Q, C_hidden] query data
k:
[*, H, K, C_hidden] key data
v:
[*, H, V, C_hidden] value data
biases:
List of biases that broadcast to [*, H, Q, K]
"""
if not ds4s_is_installed:
raise ValueError(
"_deepspeed_evo_attn requires that DeepSpeed be installed "
"and that the deepspeed.ops.deepspeed4science package exists"
)
def reshape_dims(x):
no_batch_dims = len(x.shape[:-3])
if no_batch_dims < 2:
return x.reshape(*((1,) * (2 - no_batch_dims) + x.shape))
if no_batch_dims > 2:
return x.reshape(*((x.shape[0], -1) + x.shape[-3:]))
return x
# [*, Q/K, H, C_hidden]
q = q.transpose(-2, -3)
k = k.transpose(-2, -3)
v = v.transpose(-2, -3)
# Reshape tensors to match expected input shape [B, N, Q/K, H, C_hidden]
# for DS4Sci_EvoformerAttention() by adding or flattening batch dims as needed.
orig_shape = q.shape
if len(orig_shape[:-3]) != 2:
q = reshape_dims(q)
k = reshape_dims(k)
v = reshape_dims(v)
biases = [reshape_dims(b) for b in biases]
# DeepSpeed attn. kernel requires inputs to be type bf16 or fp16
# Cast to bf16 so kernel can be used during inference
orig_dtype = q.dtype
if orig_dtype not in [torch.bfloat16, torch.float16]:
o = DS4Sci_EvoformerAttention(q.to(dtype=torch.bfloat16),
k.to(dtype=torch.bfloat16),
v.to(dtype=torch.bfloat16),
[b.to(dtype=torch.bfloat16) for b in biases])
o = o.to(dtype=orig_dtype)
else:
o = DS4Sci_EvoformerAttention(q, k, v, biases)
o = o.reshape(orig_shape)
return o
def _lma( def _lma(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
...@@ -683,7 +764,7 @@ def _lma( ...@@ -683,7 +764,7 @@ def _lma(
@torch.jit.ignore @torch.jit.ignore
def _flash_attn(q, k, v, kv_mask): def _flash_attn(q, k, v, kv_mask):
if(not fa_is_installed): if not fa_is_installed:
raise ValueError( raise ValueError(
"_flash_attn requires that FlashAttention be installed" "_flash_attn requires that FlashAttention be installed"
) )
...@@ -735,8 +816,8 @@ def _flash_attn(q, k, v, kv_mask): ...@@ -735,8 +816,8 @@ def _flash_attn(q, k, v, kv_mask):
kv_cu_seqlens, kv_cu_seqlens,
q_max_s, q_max_s,
kv_max_s, kv_max_s,
dropout_p = 0., dropout_p=0.,
softmax_scale = 1., # q has been scaled already softmax_scale=1., # q has been scaled already
) )
# [*, B, N, H, C] # [*, B, N, H, C]
......
...@@ -20,7 +20,7 @@ from typing import Optional, List ...@@ -20,7 +20,7 @@ from typing import Optional, List
import torch import torch
import torch.nn as nn import torch.nn as nn
from openfold.model.primitives import Linear, LayerNorm, Attention from openfold.model.primitives import LayerNorm, Attention
from openfold.model.dropout import ( from openfold.model.dropout import (
DropoutRowwise, DropoutRowwise,
DropoutColumnwise, DropoutColumnwise,
...@@ -48,7 +48,6 @@ from openfold.utils.feats import ( ...@@ -48,7 +48,6 @@ from openfold.utils.feats import (
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
add, add,
permute_final_dims, permute_final_dims,
flatten_final_dims,
tensor_tree_map, tensor_tree_map,
) )
...@@ -57,6 +56,7 @@ class TemplatePointwiseAttention(nn.Module): ...@@ -57,6 +56,7 @@ class TemplatePointwiseAttention(nn.Module):
""" """
Implements Algorithm 17. Implements Algorithm 17.
""" """
def __init__(self, c_t, c_z, c_hidden, no_heads, inf, **kwargs): def __init__(self, c_t, c_z, c_hidden, no_heads, inf, **kwargs):
""" """
Args: Args:
...@@ -85,12 +85,12 @@ class TemplatePointwiseAttention(nn.Module): ...@@ -85,12 +85,12 @@ class TemplatePointwiseAttention(nn.Module):
) )
def _chunk(self, def _chunk(self,
z: torch.Tensor, z: torch.Tensor,
t: torch.Tensor, t: torch.Tensor,
biases: List[torch.Tensor], biases: List[torch.Tensor],
chunk_size: int, chunk_size: int,
use_lma: bool = False, use_lma: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
mha_inputs = { mha_inputs = {
"q_x": z, "q_x": z,
"kv_x": t, "kv_x": t,
...@@ -103,15 +103,14 @@ class TemplatePointwiseAttention(nn.Module): ...@@ -103,15 +103,14 @@ class TemplatePointwiseAttention(nn.Module):
no_batch_dims=len(z.shape[:-2]), no_batch_dims=len(z.shape[:-2]),
) )
def forward(self,
def forward(self, t: torch.Tensor,
t: torch.Tensor, z: torch.Tensor,
z: torch.Tensor, template_mask: Optional[torch.Tensor] = None,
template_mask: Optional[torch.Tensor] = None, # This module suffers greatly from a small chunk size
# This module suffers greatly from a small chunk size chunk_size: Optional[int] = 256,
chunk_size: Optional[int] = 256, use_lma: bool = False,
use_lma: bool = False, ) -> torch.Tensor:
) -> torch.Tensor:
""" """
Args: Args:
t: t:
...@@ -212,13 +211,20 @@ class TemplatePairStackBlock(nn.Module): ...@@ -212,13 +211,20 @@ class TemplatePairStackBlock(nn.Module):
self.pair_transition_n, self.pair_transition_n,
) )
def tri_att_start_end(self, single, _attn_chunk_size, single_mask, use_lma, inplace_safe): def tri_att_start_end(self,
single: torch.Tensor,
_attn_chunk_size: Optional[int],
single_mask: torch.Tensor,
use_deepspeed_evo_attention: bool,
use_lma: bool,
inplace_safe: bool):
single = add(single, single = add(single,
self.dropout_row( self.dropout_row(
self.tri_att_start( self.tri_att_start(
single, single,
chunk_size=_attn_chunk_size, chunk_size=_attn_chunk_size,
mask=single_mask, mask=single_mask,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
) )
...@@ -232,6 +238,7 @@ class TemplatePairStackBlock(nn.Module): ...@@ -232,6 +238,7 @@ class TemplatePairStackBlock(nn.Module):
single, single,
chunk_size=_attn_chunk_size, chunk_size=_attn_chunk_size,
mask=single_mask, mask=single_mask,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
) )
...@@ -241,14 +248,17 @@ class TemplatePairStackBlock(nn.Module): ...@@ -241,14 +248,17 @@ class TemplatePairStackBlock(nn.Module):
return single return single
def tri_mul_out_in(self, single, single_mask, inplace_safe): def tri_mul_out_in(self,
single: torch.Tensor,
single_mask: torch.Tensor,
inplace_safe: bool):
tmu_update = self.tri_mul_out( tmu_update = self.tri_mul_out(
single, single,
mask=single_mask, mask=single_mask,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
_add_with_inplace=True, _add_with_inplace=True,
) )
if (not inplace_safe): if not inplace_safe:
single = single + self.dropout_row(tmu_update) single = single + self.dropout_row(tmu_update)
else: else:
single = tmu_update single = tmu_update
...@@ -261,7 +271,7 @@ class TemplatePairStackBlock(nn.Module): ...@@ -261,7 +271,7 @@ class TemplatePairStackBlock(nn.Module):
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
_add_with_inplace=True, _add_with_inplace=True,
) )
if (not inplace_safe): if not inplace_safe:
single = single + self.dropout_row(tmu_update) single = single + self.dropout_row(tmu_update)
else: else:
single = tmu_update single = tmu_update
...@@ -270,16 +280,17 @@ class TemplatePairStackBlock(nn.Module): ...@@ -270,16 +280,17 @@ class TemplatePairStackBlock(nn.Module):
return single return single
def forward(self, def forward(self,
z: torch.Tensor, z: torch.Tensor,
mask: torch.Tensor, mask: torch.Tensor,
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
use_lma: bool = False, use_deepspeed_evo_attention: bool = False,
inplace_safe: bool = False, use_lma: bool = False,
_mask_trans: bool = True, inplace_safe: bool = False,
_attn_chunk_size: Optional[int] = None, _mask_trans: bool = True,
): _attn_chunk_size: Optional[int] = None,
if(_attn_chunk_size is None): ):
if _attn_chunk_size is None:
_attn_chunk_size = chunk_size _attn_chunk_size = chunk_size
single_templates = [ single_templates = [
...@@ -299,16 +310,19 @@ class TemplatePairStackBlock(nn.Module): ...@@ -299,16 +310,19 @@ class TemplatePairStackBlock(nn.Module):
inplace_safe=inplace_safe), inplace_safe=inplace_safe),
_attn_chunk_size=_attn_chunk_size, _attn_chunk_size=_attn_chunk_size,
single_mask=single_mask, single_mask=single_mask,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
inplace_safe=inplace_safe) inplace_safe=inplace_safe)
else: else:
single = self.tri_mul_out_in(single=self.tri_att_start_end(single=single, single = self.tri_mul_out_in(
_attn_chunk_size=_attn_chunk_size, single=self.tri_att_start_end(single=single,
single_mask=single_mask, _attn_chunk_size=_attn_chunk_size,
use_lma=use_lma, single_mask=single_mask,
inplace_safe=inplace_safe), use_deepspeed_evo_attention=use_deepspeed_evo_attention,
single_mask=single_mask, use_lma=use_lma,
inplace_safe=inplace_safe) inplace_safe=inplace_safe),
single_mask=single_mask,
inplace_safe=inplace_safe)
single = add(single, single = add(single,
self.pair_transition( self.pair_transition(
...@@ -319,10 +333,10 @@ class TemplatePairStackBlock(nn.Module): ...@@ -319,10 +333,10 @@ class TemplatePairStackBlock(nn.Module):
inplace_safe, inplace_safe,
) )
if (not inplace_safe): if not inplace_safe:
single_templates[i] = single single_templates[i] = single
if (not inplace_safe): if not inplace_safe:
z = torch.cat(single_templates, dim=-4) z = torch.cat(single_templates, dim=-4)
return z return z
...@@ -332,6 +346,7 @@ class TemplatePairStack(nn.Module): ...@@ -332,6 +346,7 @@ class TemplatePairStack(nn.Module):
""" """
Implements Algorithm 16. Implements Algorithm 16.
""" """
def __init__( def __init__(
self, self,
c_t, c_t,
...@@ -389,7 +404,7 @@ class TemplatePairStack(nn.Module): ...@@ -389,7 +404,7 @@ class TemplatePairStack(nn.Module):
self.tune_chunk_size = tune_chunk_size self.tune_chunk_size = tune_chunk_size
self.chunk_size_tuner = None self.chunk_size_tuner = None
if(tune_chunk_size): if tune_chunk_size:
self.chunk_size_tuner = ChunkSizeTuner() self.chunk_size_tuner = ChunkSizeTuner()
def forward( def forward(
...@@ -397,6 +412,7 @@ class TemplatePairStack(nn.Module): ...@@ -397,6 +412,7 @@ class TemplatePairStack(nn.Module):
t: torch.tensor, t: torch.tensor,
mask: torch.tensor, mask: torch.tensor,
chunk_size: int, chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
inplace_safe: bool = False, inplace_safe: bool = False,
_mask_trans: bool = True, _mask_trans: bool = True,
...@@ -410,7 +426,7 @@ class TemplatePairStack(nn.Module): ...@@ -410,7 +426,7 @@ class TemplatePairStack(nn.Module):
Returns: Returns:
[*, N_templ, N_res, N_res, C_t] template embedding update [*, N_templ, N_res, N_res, C_t] template embedding update
""" """
if(mask.shape[-3] == 1): if mask.shape[-3] == 1:
expand_idx = list(mask.shape) expand_idx = list(mask.shape)
expand_idx[-3] = t.shape[-4] expand_idx[-3] = t.shape[-4]
mask = mask.expand(*expand_idx) mask = mask.expand(*expand_idx)
...@@ -420,6 +436,7 @@ class TemplatePairStack(nn.Module): ...@@ -420,6 +436,7 @@ class TemplatePairStack(nn.Module):
b, b,
mask=mask, mask=mask,
chunk_size=chunk_size, chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
_mask_trans=_mask_trans, _mask_trans=_mask_trans,
...@@ -427,18 +444,18 @@ class TemplatePairStack(nn.Module): ...@@ -427,18 +444,18 @@ class TemplatePairStack(nn.Module):
for b in self.blocks for b in self.blocks
] ]
if(chunk_size is not None and self.chunk_size_tuner is not None): if chunk_size is not None and self.chunk_size_tuner is not None:
assert(not self.training) assert (not self.training)
tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size( tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0], representative_fn=blocks[0],
args=(t.clone(),), args=(t.clone(),),
min_chunk_size=chunk_size, min_chunk_size=chunk_size,
) )
blocks = [ blocks = [
partial(b, partial(b,
chunk_size=tuned_chunk_size, chunk_size=tuned_chunk_size,
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 4), _attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
) for b in blocks ) for b in blocks
] ]
t, = checkpoint_blocks( t, = checkpoint_blocks(
...@@ -453,11 +470,11 @@ class TemplatePairStack(nn.Module): ...@@ -453,11 +470,11 @@ class TemplatePairStack(nn.Module):
def embed_templates_offload( def embed_templates_offload(
model, model,
batch, batch,
z, z,
pair_mask, pair_mask,
templ_dim, templ_dim,
template_chunk_size=256, template_chunk_size=256,
inplace_safe=False, inplace_safe=False,
): ):
...@@ -508,13 +525,15 @@ def embed_templates_offload( ...@@ -508,13 +525,15 @@ def embed_templates_offload(
# [*, 1, N, N, C_z] # [*, 1, N, N, C_z]
t = model.template_pair_stack( t = model.template_pair_stack(
t.unsqueeze(templ_dim), t.unsqueeze(templ_dim),
pair_mask.unsqueeze(-3).to(dtype=z.dtype), pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=model.globals.chunk_size, chunk_size=model.globals.chunk_size,
use_deepspeed_evo_attention=model.globals.use_deepspeed_evo_attention,
use_lma=model.globals.use_lma, use_lma=model.globals.use_lma,
inplace_safe=inplace_safe,
_mask_trans=model.config._mask_trans, _mask_trans=model.config._mask_trans,
) )
assert(sys.getrefcount(t) == 2) assert (sys.getrefcount(t) == 2)
pair_embeds_cpu.append(t.cpu()) pair_embeds_cpu.append(t.cpu())
...@@ -537,10 +556,10 @@ def embed_templates_offload( ...@@ -537,10 +556,10 @@ def embed_templates_offload(
) )
t[..., i: i + template_chunk_size, :, :] = att_chunk t[..., i: i + template_chunk_size, :, :] = att_chunk
del pair_chunks del pair_chunks
if(inplace_safe): if inplace_safe:
t = t * (torch.sum(batch["template_mask"], dim=-1) > 0) t = t * (torch.sum(batch["template_mask"], dim=-1) > 0)
else: else:
t *= (torch.sum(batch["template_mask"], dim=-1) > 0) t *= (torch.sum(batch["template_mask"], dim=-1) > 0)
...@@ -553,7 +572,7 @@ def embed_templates_offload( ...@@ -553,7 +572,7 @@ def embed_templates_offload(
# [*, N, C_m] # [*, N, C_m]
a = model.template_single_embedder(template_angle_feat) a = model.template_single_embedder(template_angle_feat)
ret["template_single_embedding"] = a ret["template_single_embedding"] = a
ret.update({"template_pair_embedding": t}) ret.update({"template_pair_embedding": t})
...@@ -562,10 +581,10 @@ def embed_templates_offload( ...@@ -562,10 +581,10 @@ def embed_templates_offload(
def embed_templates_average( def embed_templates_average(
model, model,
batch, batch,
z, z,
pair_mask, pair_mask,
templ_dim, templ_dim,
templ_group_size=2, templ_group_size=2,
inplace_safe=False, inplace_safe=False,
...@@ -601,12 +620,12 @@ def embed_templates_average( ...@@ -601,12 +620,12 @@ def embed_templates_average(
n = z.shape[-2] n = z.shape[-2]
n_templ = batch["template_aatype"].shape[templ_dim] n_templ = batch["template_aatype"].shape[templ_dim]
out_tensor = z.new_zeros(z.shape) out_tensor = z.new_zeros(z.shape)
for i in range(0, n_templ, templ_group_size): for i in range(0, n_templ, templ_group_size):
def slice_template_tensor(t): def slice_template_tensor(t):
s = [slice(None) for _ in t.shape] s = [slice(None) for _ in t.shape]
s[templ_dim] = slice(i, i + templ_group_size) s[templ_dim] = slice(i, i + templ_group_size)
return t[s] return t[s]
template_feats = tensor_tree_map( template_feats = tensor_tree_map(
slice_template_tensor, slice_template_tensor,
batch, batch,
...@@ -624,10 +643,12 @@ def embed_templates_average( ...@@ -624,10 +643,12 @@ def embed_templates_average(
# [*, S_t, N, N, C_z] # [*, S_t, N, N, C_z]
t = model.template_pair_embedder(t) t = model.template_pair_embedder(t)
t = model.template_pair_stack( t = model.template_pair_stack(
t, t,
pair_mask.unsqueeze(-3).to(dtype=z.dtype), pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=model.globals.chunk_size, chunk_size=model.globals.chunk_size,
use_deepspeed_evo_attention=model.globals.use_deepspeed_evo_attention,
use_lma=model.globals.use_lma, use_lma=model.globals.use_lma,
inplace_safe=inplace_safe,
_mask_trans=model.config._mask_trans, _mask_trans=model.config._mask_trans,
) )
...@@ -639,19 +660,19 @@ def embed_templates_average( ...@@ -639,19 +660,19 @@ def embed_templates_average(
) )
denom = math.ceil(n_templ / templ_group_size) denom = math.ceil(n_templ / templ_group_size)
if(inplace_safe): if inplace_safe:
t /= denom t /= denom
else: else:
t = t / denom t = t / denom
if(inplace_safe): if inplace_safe:
out_tensor += t out_tensor += t
else: else:
out_tensor = out_tensor + t out_tensor = out_tensor + t
del t del t
if(inplace_safe): if inplace_safe:
out_tensor *= (torch.sum(batch["template_mask"], dim=-1) > 0) out_tensor *= (torch.sum(batch["template_mask"], dim=-1) > 0)
else: else:
out_tensor = out_tensor * (torch.sum(batch["template_mask"], dim=-1) > 0) out_tensor = out_tensor * (torch.sum(batch["template_mask"], dim=-1) > 0)
...@@ -664,7 +685,7 @@ def embed_templates_average( ...@@ -664,7 +685,7 @@ def embed_templates_average(
# [*, N, C_m] # [*, N, C_m]
a = model.template_single_embedder(template_angle_feat) a = model.template_single_embedder(template_angle_feat)
ret["template_single_embedding"] = a ret["template_single_embedding"] = a
ret.update({"template_pair_embedding": out_tensor}) ret.update({"template_pair_embedding": out_tensor})
......
...@@ -63,6 +63,7 @@ class TriangleAttention(nn.Module): ...@@ -63,6 +63,7 @@ class TriangleAttention(nn.Module):
biases: List[torch.Tensor], biases: List[torch.Tensor],
chunk_size: int, chunk_size: int,
use_memory_efficient_kernel: bool = False, use_memory_efficient_kernel: bool = False,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
inplace_safe: bool = False, inplace_safe: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -77,6 +78,7 @@ class TriangleAttention(nn.Module): ...@@ -77,6 +78,7 @@ class TriangleAttention(nn.Module):
partial( partial(
self.mha, self.mha,
use_memory_efficient_kernel=use_memory_efficient_kernel, use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma use_lma=use_lma
), ),
mha_inputs, mha_inputs,
...@@ -90,6 +92,7 @@ class TriangleAttention(nn.Module): ...@@ -90,6 +92,7 @@ class TriangleAttention(nn.Module):
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False, use_memory_efficient_kernel: bool = False,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
inplace_safe: bool = False, inplace_safe: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -130,6 +133,7 @@ class TriangleAttention(nn.Module): ...@@ -130,6 +133,7 @@ class TriangleAttention(nn.Module):
biases, biases,
chunk_size, chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel, use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
) )
...@@ -139,6 +143,7 @@ class TriangleAttention(nn.Module): ...@@ -139,6 +143,7 @@ class TriangleAttention(nn.Module):
kv_x=x, kv_x=x,
biases=biases, biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel, use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma use_lma=use_lma
) )
......
...@@ -181,6 +181,7 @@ def trace_model_(model, sample_input): ...@@ -181,6 +181,7 @@ def trace_model_(model, sample_input):
("mask", msa_mask), ("mask", msa_mask),
("chunk_size", torch.tensor(evoformer_attn_chunk_size)), ("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
("use_memory_efficient_kernel", torch.tensor(False)), ("use_memory_efficient_kernel", torch.tensor(False)),
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
("use_lma", torch.tensor(model.globals.use_lma)), ("use_lma", torch.tensor(model.globals.use_lma)),
] ]
verify_arg_order( verify_arg_order(
...@@ -201,6 +202,7 @@ def trace_model_(model, sample_input): ...@@ -201,6 +202,7 @@ def trace_model_(model, sample_input):
("m", m), ("m", m),
("mask", msa_mask), ("mask", msa_mask),
("chunk_size", torch.tensor(evoformer_chunk_size)), ("chunk_size", torch.tensor(evoformer_chunk_size)),
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
("use_lma", torch.tensor(model.globals.use_lma)), ("use_lma", torch.tensor(model.globals.use_lma)),
("use_flash", torch.tensor(model.globals.use_flash)), ("use_flash", torch.tensor(model.globals.use_flash)),
] ]
...@@ -283,6 +285,7 @@ def trace_model_(model, sample_input): ...@@ -283,6 +285,7 @@ def trace_model_(model, sample_input):
("mask", pair_mask.float()), ("mask", pair_mask.float()),
("chunk_size", torch.tensor(evoformer_attn_chunk_size)), ("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
("use_memory_efficient_kernel", torch.tensor(False)), ("use_memory_efficient_kernel", torch.tensor(False)),
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
("use_lma", torch.tensor(model.globals.use_lma)), ("use_lma", torch.tensor(model.globals.use_lma)),
("inplace_safe", torch.tensor(True)), ("inplace_safe", torch.tensor(True)),
] ]
...@@ -305,6 +308,7 @@ def trace_model_(model, sample_input): ...@@ -305,6 +308,7 @@ def trace_model_(model, sample_input):
("mask", pair_mask.transpose(-1, -2).float()), ("mask", pair_mask.transpose(-1, -2).float()),
("chunk_size", torch.tensor(evoformer_attn_chunk_size)), ("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
("use_memory_efficient_kernel", torch.tensor(False)), ("use_memory_efficient_kernel", torch.tensor(False)),
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
("use_lma", torch.tensor(model.globals.use_lma)), ("use_lma", torch.tensor(model.globals.use_lma)),
("inplace_safe", torch.tensor(True)), ("inplace_safe", torch.tensor(True)),
] ]
......
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads OpenFold SoloSeq parameters.
#
# Usage: bash download_openfold_params_huggingface.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aws &> /dev/null ; then
echo "Error: aws could not be found. Please install the AWS CLI."
exit 1
fi
DOWNLOAD_DIR="${1}/openfold_soloseq_params"
mkdir -p "${DOWNLOAD_DIR}"
aws s3 cp --no-sign-request --region us-east-1 s3://openfold/openfold_soloseq_params/ "${DOWNLOAD_DIR}" --recursive
...@@ -13,6 +13,12 @@ gunzip -c tests/test_data/sample_feats.pickle.gz > tests/test_data/sample_feats. ...@@ -13,6 +13,12 @@ gunzip -c tests/test_data/sample_feats.pickle.gz > tests/test_data/sample_feats.
python setup.py install python setup.py install
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH echo "Downloading CUTLASS, required for the DeepSpeed Evoformer attention kernel"
git clone https://github.com/NVIDIA/cutlass --depth 1
conda env config vars set CUTLASS_PATH=$PWD/cutlass
# This setting is used to fix a worker assignment issue during data loading # This setting is used to fix a worker assignment issue during data loading
conda env config vars set KMP_AFFINITY=none conda env config vars set KMP_AFFINITY=none
export LIBRARY_PATH=$CONDA_PREFIX/lib:$LIBRARY_PATH
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH
...@@ -10,7 +10,6 @@ import numpy as np ...@@ -10,7 +10,6 @@ import numpy as np
from openfold.config import model_config from openfold.config import model_config
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_ from openfold.utils.import_weights import import_jax_weights_
from tests.config import consts
# Give JAX some GPU memory discipline # Give JAX some GPU memory discipline
# (by default it hogs 90% of GPU memory. This disables that behavior and also # (by default it hogs 90% of GPU memory. This disables that behavior and also
...@@ -19,6 +18,18 @@ os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" ...@@ -19,6 +18,18 @@ os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
os.environ["JAX_PLATFORM_NAME"] = "gpu" os.environ["JAX_PLATFORM_NAME"] = "gpu"
def skip_unless_ds4s_installed():
deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
ds4s_is_installed = deepspeed_is_installed and importlib.util.find_spec(
"deepspeed.ops.deepspeed4science") is not None
    return unittest.skipUnless(ds4s_is_installed, "Requires DeepSpeed ≥ 0.10.4")
def skip_unless_flash_attn_installed():
fa_is_installed = importlib.util.find_spec("flash_attn") is not None
return unittest.skipUnless(fa_is_installed, "Requires Flash Attention")
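Both helpers above follow the same pattern: probe for a module with `importlib.util.find_spec` and turn the result into a `unittest.skipUnless` decorator. A generic sketch of that pattern (the name `skip_unless_module_installed` is hypothetical, for illustration only):

```python
import importlib.util
import unittest

def skip_unless_module_installed(name: str, reason: str = None):
    """Build a skip decorator from an importability check, in the same
    style as the skip_unless_* helpers above."""
    installed = importlib.util.find_spec(name) is not None
    return unittest.skipUnless(installed, reason or f"Requires {name}")

@skip_unless_module_installed("json")  # stdlib module, so this class runs
class TestAlwaysRuns(unittest.TestCase):
    def test_ok(self):
        self.assertTrue(True)

@skip_unless_module_installed("no_such_module_xyz")  # missing, so skipped
class TestAlwaysSkipped(unittest.TestCase):
    def test_never_reached(self):
        self.fail("should have been skipped")
```

Because the check runs at decoration time, the skip reason shows up in the test report without ever importing the optional dependency.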
def alphafold_is_installed(): def alphafold_is_installed():
return importlib.util.find_spec("alphafold") is not None return importlib.util.find_spec("alphafold") is not None
......
...@@ -13,6 +13,7 @@
# limitations under the License.
from random import randint
import torch
import numpy as np
from scipy.spatial.transform import Rotation
...@@ -127,3 +128,17 @@ def random_affines_4x4(dim):
affines[:, 3, 3] = 1
return affines.reshape(*dim, 4, 4)
def random_attention_inputs(batch_size, n_seq, n, no_heads, c_hidden, inf=1e9,
dtype=torch.float32, requires_grad=False):
q = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype, requires_grad=requires_grad).cuda()
kv = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype, requires_grad=requires_grad).cuda()
mask = torch.randint(0, 2, (batch_size, n_seq, 1, 1, n), dtype=dtype, requires_grad=False).cuda()
z_bias = torch.rand(batch_size, 1, no_heads, n, n, dtype=dtype, requires_grad=requires_grad).cuda()
mask_bias = inf * (mask - 1)
biases = [mask_bias, z_bias]
return q, kv, mask, biases
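The `mask_bias = inf * (mask - 1)` trick in the helper above converts a 0/1 mask into an additive attention bias: positions where the mask is 0 receive a large negative bias and therefore get (near-)zero weight after softmax. A minimal NumPy sketch of that conversion (illustrative only, with made-up values, not part of the repo):

```python
import numpy as np

inf = 1e9
mask = np.array([1.0, 1.0, 0.0, 1.0])   # 0 marks a position to ignore
mask_bias = inf * (mask - 1)            # [0, 0, -1e9, 0]: additive bias

logits = np.array([0.5, 1.5, 3.0, 0.1]) + mask_bias
weights = np.exp(logits - logits.max())
weights = weights / weights.sum()       # softmax over the biased logits
# weights[2] is effectively zero: the masked position gets no attention
```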
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unit tests to compare components of OpenFold run with the DeepSpeed memory-efficient
attention kernel, DS4Sci_EvoformerAttention vs. a stock PyTorch attention implementation.
"""
import unittest
import numpy as np
import pickle
import torch
from torch.nn import functional as F
from openfold.data import data_transforms
from openfold.model.primitives import (
lecun_normal_init_,
Attention
)
from openfold.utils.tensor_utils import tensor_tree_map
from tests.config import consts
import tests.compare_utils as compare_utils
from tests.data_utils import random_template_feats, random_attention_inputs
@compare_utils.skip_unless_ds4s_installed()
class TestDeepSpeedKernel(unittest.TestCase):
def compare_attention_types(self, use_flash=False):
"""Compare attention with and without using DeepSpeed Evoformer kernel."""
batch_size = consts.batch_size
n_seq = 18
n_res = 20
c_hidden = 32
no_heads = 4
eps = 2e-2
q, kv, mask, biases = random_attention_inputs(batch_size=batch_size,
n_seq=n_seq,
n=n_res,
no_heads=no_heads,
c_hidden=c_hidden)
a = Attention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda()
# Change output params init for testing since they are initialized with 'final' init (zeros)
# Otherwise both will just return zero.
with torch.no_grad():
lecun_normal_init_(a.linear_g.weight)
lecun_normal_init_(a.linear_o.weight)
if use_flash:
biases = [biases[0]]
flash_mask = mask.reshape(batch_size * n_seq, n_res)
real_out = a(q, kv, use_flash=True, flash_mask=flash_mask).cpu()
else:
real_out = a(q, kv, biases=biases).cpu()
ds_out = a(q, kv, biases=biases, use_deepspeed_evo_attention=True).cpu()
err = torch.max(torch.abs(ds_out - real_out))
self.assertTrue(err < eps, f'Error: {err}')
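The equivalence criterion used here (and in the tests below) is simply the elementwise maximum absolute difference between the two implementations' outputs. A small self-contained sketch of that pattern, with hypothetical values standing in for the two attention outputs:

```python
import numpy as np

# Hypothetical outputs of a reference and an optimized attention implementation
ref = np.array([0.10, 0.20, 0.30])
opt = ref + np.array([1e-3, -2e-3, 0.0])   # small kernel rounding differences

err = np.max(np.abs(opt - ref))            # max elementwise absolute error
assert err < 2e-2                          # same tolerance (eps) as the forward test
```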
def test_ds_kernel_vs_attention_forward(self):
"""Compare regular attention vs. DeepSpeed Evoformer kernel."""
self.compare_attention_types(use_flash=False)
@compare_utils.skip_unless_flash_attn_installed()
def test_ds_kernel_vs_flash_attn_forward(self):
"""Compare Flash Attention vs. DeepSpeed Evoformer kernel."""
self.compare_attention_types(use_flash=True)
def test_ds_kernel_vs_attention_backward(self):
"""Compare backward pass for regular attention vs. DeepSpeed Evoformer kernel."""
batch_size = consts.batch_size
n_seq = 18
n_res = 20
c_hidden = 32
no_heads = 4
eps = consts.eps
q, kv, mask, biases = random_attention_inputs(batch_size=batch_size,
n_seq=n_seq,
n=n_res,
no_heads=no_heads,
c_hidden=c_hidden,
requires_grad=True)
attn = Attention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda()
with torch.no_grad():
lecun_normal_init_(attn.linear_g.weight)
lecun_normal_init_(attn.linear_o.weight)
def clone(t):
# Create new params, clone values
t = t.clone()
if t.requires_grad:
t.retain_grad()
return t
def init_attn():
# Create new attention object with same initial weights
a_clone = Attention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda()
a_clone.load_state_dict(attn.state_dict())
return a_clone
# Clone param values and run attention with DS kernel
q_repro = clone(q)
kv_repro = clone(kv)
biases_repro = [clone(b) for b in biases]
a_repro = init_attn()
out_repro = a_repro(q_repro, kv_repro, biases=biases_repro, use_deepspeed_evo_attention=True)
loss_repro = torch.mean(out_repro)
loss_repro.backward()
q_gt = clone(q)
kv_gt = clone(kv)
biases_gt = [clone(b) for b in biases]
# Clone param values and run attention without DS kernel
a_gt = init_attn()
out_gt = a_gt(q_gt, kv_gt, biases=biases_gt)
loss_gt = torch.mean(out_gt)
loss_gt.backward()
# Compare the grads of attention inputs
pairs = zip([q_repro, kv_repro, biases_repro[1]],
[q_gt, kv_gt, biases_gt[1]])
for i, item in enumerate(pairs):
t_repro, t_gt = item
err = torch.max(torch.abs(t_repro.grad.cpu() - t_gt.grad.cpu()))
self.assertTrue(err < eps, f'Error item #{i}: {err}')
# Compare the grads of model weights
a_repro_params = dict(a_repro.named_parameters())
a_gt_params = dict(a_gt.named_parameters())
for name in a_gt_params.keys():
t_repro = a_repro_params[name]
t_gt = a_gt_params[name]
err = torch.max(torch.abs(t_repro.grad.cpu() - t_gt.grad.cpu()))
self.assertTrue(err < eps, f'Error item {name}: {err}')
def compare_evoformer(self, dtype, eps):
"""
Compare Evoformer output with and without using DeepSpeed Evoformer attention kernel.
Set dtype to confirm the kernel can be used during both training (BF16) and inference (FP32),
since the kernel itself can run with either BF16 or FP16 precision.
"""
n_res = 20
n_seq = 18
c_m_shape = (consts.c_m,)
c_z_shape = (consts.c_z,)
activations = {
"msa": torch.rand(n_seq, n_res, consts.c_m, device='cuda', dtype=dtype),
"pair": torch.rand(n_res, n_res, consts.c_z, device='cuda', dtype=dtype)
}
masks = {
"msa": torch.randint(0, 2, (n_seq, n_res), device='cuda', dtype=dtype),
"pair": torch.randint(0, 2, (n_res, n_res), device='cuda', dtype=dtype),
}
with torch.cuda.amp.autocast(dtype=dtype):
model = compare_utils.get_global_pretrained_openfold()
out_repro_msa, out_repro_pair = model.evoformer.blocks[0](
activations["msa"],
activations["pair"],
masks["msa"],
masks["pair"],
use_deepspeed_evo_attention=False,
chunk_size=4,
_mask_trans=False,
inplace_safe=False,
)
# In practice, layer norms applied later in the network make any
# kernel rounding errors negligible
out_repro_msa = F.layer_norm(out_repro_msa, c_m_shape).cpu()
out_repro_pair = F.layer_norm(out_repro_pair, c_z_shape).cpu()
out_repro_msa_ds, out_repro_pair_ds = model.evoformer.blocks[0](
activations["msa"],
activations["pair"],
masks["msa"],
masks["pair"],
use_deepspeed_evo_attention=True,
chunk_size=4,
_mask_trans=False,
inplace_safe=False,
)
out_repro_msa_ds = F.layer_norm(out_repro_msa_ds, c_m_shape).cpu()
out_repro_pair_ds = F.layer_norm(out_repro_pair_ds, c_z_shape).cpu()
err = torch.mean(torch.abs(out_repro_msa - out_repro_msa_ds))
self.assertTrue(err < eps, f'MSA Error: {err}')
err = torch.mean(torch.abs(out_repro_pair - out_repro_pair_ds))
self.assertTrue(err < eps, f'Pair Error: {err}')
def test_compare_evoformer_bf16(self):
"""Run evoformer comparison test with BF16 precision."""
self.compare_evoformer(dtype=torch.bfloat16, eps=4e-2)
def test_compare_evoformer_fp32(self):
"""Run evoformer comparison test with FP32 precision."""
self.compare_evoformer(dtype=torch.float32, eps=2e-2)
def test_compare_template_stack(self):
"""
Compare Template Stack output with and without using DeepSpeed Evoformer attention kernel.
Kernel can be used for Triangle Attention in the Template Pair Stack.
"""
n_templ = consts.n_templ
n_res = 20
eps = 2e-2
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
batch = random_template_feats(n_templ, n_res)
batch["template_all_atom_masks"] = batch["template_all_atom_mask"]
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
inds = np.random.randint(0, 21, (n_res,))
batch["target_feat"] = np.eye(22)[inds]
with torch.no_grad():
model = compare_utils.get_global_pretrained_openfold()
model.globals.use_deepspeed_evo_attention = False
out_repro = model.embed_templates(
{k: torch.as_tensor(v).cuda() for k, v in batch.items()},
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
inplace_safe=False
)
out_repro = out_repro["template_pair_embedding"].cpu()
model.globals.use_deepspeed_evo_attention = True
out_repro_ds = model.embed_templates(
{k: torch.as_tensor(v).cuda() for k, v in batch.items()},
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
inplace_safe=False
)
out_repro_ds = out_repro_ds["template_pair_embedding"].cpu()
err = torch.max(torch.abs(out_repro - out_repro_ds))
self.assertTrue(err < eps, f'Error: {err}')
def test_compare_model(self):
"""
Run full model with and without using DeepSpeed Evoformer attention kernel
and compare output coordinates.
"""
eps = 0.5
with open("tests/test_data/sample_feats.pickle", "rb") as fp:
batch = pickle.load(fp)
# atom37_to_atom14 doesn't like batches
batch["residx_atom14_to_atom37"] = batch["residx_atom14_to_atom37"][0]
batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]
batch["no_recycling_iters"] = np.array([3., 3., 3., 3., ])
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
batch["aatype"] = batch["aatype"].long()
batch["template_aatype"] = batch["template_aatype"].long()
batch["extra_msa"] = batch["extra_msa"].long()
batch["residx_atom37_to_atom14"] = batch[
"residx_atom37_to_atom14"
].long()
batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
batch.update(
data_transforms.atom37_to_torsion_angles("template_")(batch)
)
# Move the recycling dimension to the end
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
batch = tensor_tree_map(move_dim, batch)
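The `move_dim` lambda above permutes the leading (recycling) axis of each tensor to the end. The same operation expressed in NumPy, as an illustrative sketch:

```python
import numpy as np

t = np.zeros((4, 2, 3))        # axis 0 is the recycling dimension
moved = np.moveaxis(t, 0, -1)  # equivalent of permute(1, ..., n-1, 0)
# moved.shape == (2, 3, 4): recycling dimension is now last
```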
with torch.no_grad():
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
model = compare_utils.get_global_pretrained_openfold()
model.globals.use_deepspeed_evo_attention = False
out_repro = model(batch)
# Enable kernel
model.globals.use_deepspeed_evo_attention = True
out_repro_ds = model(batch)
out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
out_repro_ds = tensor_tree_map(lambda t: t.cpu(), out_repro_ds)
out_repro = out_repro["sm"]["positions"][-1].squeeze(0)
out_repro_ds = out_repro_ds["sm"]["positions"][-1].squeeze(0)
err = torch.max(torch.abs(out_repro - out_repro_ds))
self.assertTrue(err < eps, f'Error: {err}')
if __name__ == "__main__":
unittest.main()