"launch/tio/src/input/text.rs" did not exist on "2fd6592f2d080a571b84f48c69bee9c23e4c9cc9"
Commit 9236c1e3 authored by Christina Floristean's avatar Christina Floristean
Browse files

Merge branch 'main' into deepspeed-evo-attention

parents 1271a03f 2dc080ce
...@@ -115,7 +115,13 @@ the model, consult `run_pretrained_openfold.py`. ...@@ -115,7 +115,13 @@ the model, consult `run_pretrained_openfold.py`.
### Inference ### Inference
To run inference on a sequence or multiple sequences using a set of DeepMind's To run inference on a sequence or multiple sequences using a set of DeepMind's
pretrained parameters, run e.g.: pretrained parameters, first download the OpenFold weights e.g.:
```bash
bash scripts/download_openfold_params.sh openfold/resources
```
then run e.g.:
```bash ```bash
python3 run_pretrained_openfold.py \ python3 run_pretrained_openfold.py \
...@@ -226,6 +232,14 @@ python scripts/precompute_embeddings.py fasta_dir/ embeddings_output_dir/ ...@@ -226,6 +232,14 @@ python scripts/precompute_embeddings.py fasta_dir/ embeddings_output_dir/
In the same per-label subdirectories inside `embeddings_output_dir`, you can also place `*.hhr` files (outputs from HHSearch), which can contain the details about the structures that you want to use as templates. If you do not place any such file, templates will not be used and only the ESM-1b embeddings will be used to predict the structure. If you want to use templates, you need to pass the PDB MMCIF dataset to the command. In the same per-label subdirectories inside `embeddings_output_dir`, you can also place `*.hhr` files (outputs from HHSearch), which can contain the details about the structures that you want to use as templates. If you do not place any such file, templates will not be used and only the ESM-1b embeddings will be used to predict the structure. If you want to use templates, you need to pass the PDB MMCIF dataset to the command.
Then download the SoloSeq model weights, e.g.:
```bash
bash scripts/download_openfold_soloseq_params.sh openfold/resources
```
Now, you are ready to run inference: Now, you are ready to run inference:
```bash ```bash
python run_pretrained_openfold.py \ python run_pretrained_openfold.py \
...@@ -235,7 +249,7 @@ python run_pretrained_openfold.py \ ...@@ -235,7 +249,7 @@ python run_pretrained_openfold.py \
--output_dir ./ \ --output_dir ./ \
--model_device "cuda:0" \ --model_device "cuda:0" \
--config_preset "seq_model_esm1b_ptm" \ --config_preset "seq_model_esm1b_ptm" \
--openfold_checkpoint_path openfold/resources/openfold_params/seq_model_esm1b_ptm.pt --openfold_checkpoint_path openfold/resources/openfold_soloseq_params/seq_model_esm1b_ptm.pt
``` ```
For generating the embeddings during inference, skip the `--use_precomputed_alignments` argument. The `*.hhr` files will be generated as well if you pass the paths to the relevant databases and tools, as specified in the command below. If you skip the database and tool arguments, HHSearch will not be used to find templates and only generated ESM-1b embeddings will be used to predict the structure. For generating the embeddings during inference, skip the `--use_precomputed_alignments` argument. The `*.hhr` files will be generated as well if you pass the paths to the relevant databases and tools, as specified in the command below. If you skip the database and tool arguments, HHSearch will not be used to find templates and only generated ESM-1b embeddings will be used to predict the structure.
......
{ {
"nbformat": 4, "cells": [
"nbformat_minor": 0, {
"cell_type": "markdown",
"metadata": { "metadata": {
"accelerator": "GPU", "id": "view-in-github",
"colab": { "colab_type": "text"
"name": "OpenFold.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}, },
"language_info": { "source": [
"name": "python" "<a href=\"https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
} ]
}, },
"cells": [
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
...@@ -57,10 +50,12 @@ ...@@ -57,10 +50,12 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"id": "rowN0bVYLe9n", "cellView": "form",
"cellView": "form" "id": "rowN0bVYLe9n"
}, },
"outputs": [],
"source": [ "source": [
"#@markdown ### Enter the amino acid sequence to fold ⬇️\n", "#@markdown ### Enter the amino acid sequence to fold ⬇️\n",
"sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n", "sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n",
...@@ -78,16 +73,16 @@ ...@@ -78,16 +73,16 @@
"\n", "\n",
"#@markdown After making your selections, execute this cell by pressing the\n", "#@markdown After making your selections, execute this cell by pressing the\n",
"#@markdown *Play* button on the left." "#@markdown *Play* button on the left."
], ]
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"id": "woIxeCPygt7K", "cellView": "form",
"cellView": "form" "id": "woIxeCPygt7K"
}, },
"outputs": [],
"source": [ "source": [
"#@title Install third-party software\n", "#@title Install third-party software\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n", "#@markdown Please execute this cell by pressing the *Play* button on \n",
...@@ -97,75 +92,54 @@ ...@@ -97,75 +92,54 @@
"#@markdown **Note**: This installs the software on the Colab \n", "#@markdown **Note**: This installs the software on the Colab \n",
"#@markdown notebook in the cloud and not on your computer.\n", "#@markdown notebook in the cloud and not on your computer.\n",
"\n", "\n",
"import sys\n", "import os, time\n",
"from IPython.utils import io\n", "from IPython.utils import io\n",
"import os\n", "from sys import version_info\n",
"import subprocess\n", "import subprocess\n",
"import tqdm.notebook\n",
"\n", "\n",
"TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'\n", "python_version = f\"{version_info.major}.{version_info.minor}\"\n",
"\n",
"\n",
"os.system(\"wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh\")\n",
"os.system(\"bash Mambaforge-Linux-x86_64.sh -bfp /usr/local\")\n",
"os.system(\"mamba config --set auto_update_conda false\")\n",
"os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=7.7.0 python={python_version} pdbfixer\")\n",
"\n",
"\n", "\n",
"python_version = '.'.join(sys.version.split('.')[:2]) #get string like \"3.9\"\n", "os.system(\"pip install -q \\\"torch<2\\\" biopython ml_collections py3Dmol modelcif\")\n",
"\n", "\n",
"try:\n", "try:\n",
" with io.capture_output() as captured:\n", " with io.capture_output() as captured:\n",
" %shell sudo apt install --quiet --yes hmmer\n",
"\n",
" # Install py3dmol.\n",
" %shell pip install py3dmol\n",
"\n",
" %shell rm -rf /opt/conda\n",
" %shell wget -q -P /tmp \\\n",
" https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \\\n",
" && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \\\n",
" && rm /tmp/Miniconda3-latest-Linux-x86_64.sh\n",
"\n",
" PATH=%env PATH\n",
" %env PATH=/opt/conda/bin:{PATH}\n",
"\n",
" # Install the required versions of all dependencies.\n",
" %shell conda install -y -q conda==4.13.0\n",
" %shell conda install -y -q -c conda-forge -c bioconda \\\n",
" kalign2=2.04 \\\n",
" hhsuite=3.3.0 \\\n",
" python={python_version} \\\n",
" openmm=7.7.0 \\\n",
" pdbfixer \\\n",
" 2>&1 1>/dev/null\n",
" %shell pip install -q \\\n",
" ml-collections==0.1.0 \\\n",
" PyYAML==5.4.1 \\\n",
" biopython==1.79 \\\n",
" modelcif==0.7\n",
"\n", "\n",
" # Create a ramdisk to store a database chunk to make Jackhmmer run fast.\n", " # Create a ramdisk to store a database chunk to make Jackhmmer run fast.\n",
" %shell sudo apt install --quiet --yes hmmer\n",
" %shell sudo mkdir -m 777 --parents /tmp/ramdisk\n", " %shell sudo mkdir -m 777 --parents /tmp/ramdisk\n",
" %shell sudo mount -t tmpfs -o size=9G ramdisk /tmp/ramdisk\n", " %shell sudo mount -t tmpfs -o size=9G ramdisk /tmp/ramdisk\n",
"\n", "\n",
" %shell wget -q -P /content \\\n", " %shell wget -q -P /content \\\n",
" https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt\n", " https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt\n",
"\n", "\n",
" # Install AWS CLI\n", " %shell mkdir -p /content/openfold/openfold/resourcees\n",
" %shell curl \"https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip\" -o \"awscliv2.zip\"\n", " \n",
" %shell unzip -qq awscliv2.zip\n", " commit = \"099769d2ecfd01a8baa8d950030df454a042c910\"\n",
" %shell sudo ./aws/install\n", " os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n",
" %shell rm awscliv2.zip\n", " \n",
" %shell rm -rf ./aws\n", " %shell cp -f /content/stereo_chemical_props.txt /usr/local/lib/python3.10/site-packages/openfold/resources/\n",
"\n",
"except subprocess.CalledProcessError as captured:\n", "except subprocess.CalledProcessError as captured:\n",
" print(captured)\n", " print(captured)"
" raise" ]
],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"id": "VzJ5iMjTtoZw", "cellView": "form",
"cellView": "form" "id": "VzJ5iMjTtoZw"
}, },
"outputs": [],
"source": [ "source": [
"#@title Install OpenFold\n", "#@title Download model weights \n",
"#@markdown Please execute this cell by pressing the *Play* button on \n", "#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@markdown the left.\n", "#@markdown the left.\n",
"\n", "\n",
...@@ -180,13 +154,6 @@ ...@@ -180,13 +154,6 @@
"\n", "\n",
"try:\n", "try:\n",
" with io.capture_output() as captured:\n", " with io.capture_output() as captured:\n",
" # Run setup.py to install only Openfold.\n",
" %shell rm -rf openfold\n",
" %shell git clone \"{GIT_REPO}\" openfold 2>&1 1> /dev/null\n",
" %shell mkdir -p /content/openfold/openfold/resources\n",
" %shell cp -f /content/stereo_chemical_props.txt /content/openfold/openfold/resources\n",
" %shell /usr/bin/python3 -m pip install -q ./openfold\n",
"\n",
" if(weight_set == 'AlphaFold'):\n", " if(weight_set == 'AlphaFold'):\n",
" %shell mkdir --parents \"{ALPHAFOLD_PARAMS_DIR}\"\n", " %shell mkdir --parents \"{ALPHAFOLD_PARAMS_DIR}\"\n",
" %shell wget -O {ALPHAFOLD_PARAMS_PATH} {ALPHAFOLD_PARAM_SOURCE_URL}\n", " %shell wget -O {ALPHAFOLD_PARAMS_PATH} {ALPHAFOLD_PARAM_SOURCE_URL}\n",
...@@ -194,7 +161,14 @@ ...@@ -194,7 +161,14 @@
" --directory=\"{ALPHAFOLD_PARAMS_DIR}\" --preserve-permissions\n", " --directory=\"{ALPHAFOLD_PARAMS_DIR}\" --preserve-permissions\n",
" %shell rm \"{ALPHAFOLD_PARAMS_PATH}\"\n", " %shell rm \"{ALPHAFOLD_PARAMS_PATH}\"\n",
" elif(weight_set == 'OpenFold'):\n", " elif(weight_set == 'OpenFold'):\n",
" # Install AWS CLI\n",
" %shell curl \"https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip\" -o \"awscliv2.zip\"\n",
" %shell unzip -qq awscliv2.zip\n",
" %shell sudo ./aws/install\n",
" %shell rm awscliv2.zip\n",
" %shell rm -rf ./aws\n",
" %shell mkdir --parents \"{OPENFOLD_PARAMS_DIR}\"\n", " %shell mkdir --parents \"{OPENFOLD_PARAMS_DIR}\"\n",
"\n",
" %shell aws s3 cp \\\n", " %shell aws s3 cp \\\n",
" --no-sign-request \\\n", " --no-sign-request \\\n",
" --region us-east-1 \\\n", " --region us-east-1 \\\n",
...@@ -203,14 +177,17 @@ ...@@ -203,14 +177,17 @@
" else:\n", " else:\n",
" raise ValueError(\"Invalid weight set\")\n", " raise ValueError(\"Invalid weight set\")\n",
"except subprocess.CalledProcessError as captured:\n", "except subprocess.CalledProcessError as captured:\n",
" print(captured)\n", " print(captured)"
" raise" ]
],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "_FpxxMo-mvcP"
},
"outputs": [],
"source": [ "source": [
"#@title Import Python packages\n", "#@title Import Python packages\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n", "#@markdown Please execute this cell by pressing the *Play* button on \n",
...@@ -219,8 +196,8 @@ ...@@ -219,8 +196,8 @@
"import unittest.mock\n", "import unittest.mock\n",
"import sys\n", "import sys\n",
"\n", "\n",
"sys.path.insert(0, f'/usr/local/lib/python{python_version}/dist-packages/')\n",
"sys.path.insert(0, f'/usr/local/lib/python{python_version}/site-packages/')\n", "sys.path.insert(0, f'/usr/local/lib/python{python_version}/site-packages/')\n",
"sys.path.append(f'/opt/conda/lib/python{python_version}/site-packages')\n",
"\n", "\n",
"# Allows us to skip installing these packages\n", "# Allows us to skip installing these packages\n",
"unnecessary_modules = [\n", "unnecessary_modules = [\n",
...@@ -245,6 +222,10 @@ ...@@ -245,6 +222,10 @@
"import py3Dmol\n", "import py3Dmol\n",
"import torch\n", "import torch\n",
"import shutil\n", "import shutil\n",
"import tqdm\n",
"import tqdm.notebook\n",
"\n",
"TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'\n",
"\n", "\n",
"# Prevent shell magic being broken by openmm, prevent this cryptic error:\n", "# Prevent shell magic being broken by openmm, prevent this cryptic error:\n",
"# \"NotImplementedError: A UTF-8 locale is required. Got ANSI_X3.4-1968\"\n", "# \"NotImplementedError: A UTF-8 locale is required. Got ANSI_X3.4-1968\"\n",
...@@ -280,13 +261,7 @@ ...@@ -280,13 +261,7 @@
"from IPython import display\n", "from IPython import display\n",
"from ipywidgets import GridspecLayout\n", "from ipywidgets import GridspecLayout\n",
"from ipywidgets import Output" "from ipywidgets import Output"
], ]
"metadata": {
"id": "_FpxxMo-mvcP",
"cellView": "form"
},
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
...@@ -301,10 +276,12 @@ ...@@ -301,10 +276,12 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"id": "2tTeTTsLKPjB", "cellView": "form",
"cellView": "form" "id": "2tTeTTsLKPjB"
}, },
"outputs": [],
"source": [ "source": [
"#@title Search against genetic databases\n", "#@title Search against genetic databases\n",
"\n", "\n",
...@@ -420,16 +397,16 @@ ...@@ -420,16 +397,16 @@
"plt.ylabel('Non-Gap Count')\n", "plt.ylabel('Non-Gap Count')\n",
"plt.yticks(range(0, num_alignments + 1, max(1, int(num_alignments / 3))))\n", "plt.yticks(range(0, num_alignments + 1, max(1, int(num_alignments / 3))))\n",
"plt.show()" "plt.show()"
], ]
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"id": "XUo6foMQxwS2", "cellView": "form",
"cellView": "form" "id": "XUo6foMQxwS2"
}, },
"outputs": [],
"source": [ "source": [
"#@title Run OpenFold and download prediction\n", "#@title Run OpenFold and download prediction\n",
"\n", "\n",
...@@ -693,9 +670,7 @@ ...@@ -693,9 +670,7 @@
"# --- Download the predictions ---\n", "# --- Download the predictions ---\n",
"shutil.make_archive(base_name='prediction', format='zip', root_dir=output_dir)\n", "shutil.make_archive(base_name='prediction', format='zip', root_dir=output_dir)\n",
"files.download(f'{output_dir}.zip')" "files.download(f'{output_dir}.zip')"
], ]
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
...@@ -789,5 +764,24 @@ ...@@ -789,5 +764,24 @@
"* BFD: (modified), by Steinegger M. and Söding J., modified by DeepMind, available under a [Creative Commons Attribution-ShareAlike 4.0 International License](https://creativecommons.org/licenses/by/4.0/). See the Methods section of the [AlphaFold proteome paper](https://www.nature.com/articles/s41586-021-03828-1) for details." "* BFD: (modified), by Steinegger M. and Söding J., modified by DeepMind, available under a [Creative Commons Attribution-ShareAlike 4.0 International License](https://creativecommons.org/licenses/by/4.0/). See the Methods section of the [AlphaFold proteome paper](https://www.nature.com/articles/s41586-021-03828-1) for details."
] ]
} }
] ],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "OpenFold.ipynb",
"provenance": [],
"gpuType": "T4",
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
} }
...@@ -165,10 +165,12 @@ def model_config( ...@@ -165,10 +165,12 @@ def model_config(
elif name == "seqemb_initial_training": elif name == "seqemb_initial_training":
c.data.train.max_msa_clusters = 1 c.data.train.max_msa_clusters = 1
c.data.eval.max_msa_clusters = 1 c.data.eval.max_msa_clusters = 1
c.data.train.block_delete_msa = False
c.data.train.max_distillation_msa_clusters = 1 c.data.train.max_distillation_msa_clusters = 1
elif name == "seqemb_finetuning": elif name == "seqemb_finetuning":
c.data.train.max_msa_clusters = 1 c.data.train.max_msa_clusters = 1
c.data.eval.max_msa_clusters = 1 c.data.eval.max_msa_clusters = 1
c.data.train.block_delete_msa = False
c.data.train.max_distillation_msa_clusters = 1 c.data.train.max_distillation_msa_clusters = 1
c.data.train.crop_size = 384 c.data.train.crop_size = 384
c.loss.violation.weight = 1. c.loss.violation.weight = 1.
...@@ -321,6 +323,11 @@ config = mlc.ConfigDict( ...@@ -321,6 +323,11 @@ config = mlc.ConfigDict(
"true_msa": [NUM_MSA_SEQ, NUM_RES], "true_msa": [NUM_MSA_SEQ, NUM_RES],
"use_clamped_fape": [], "use_clamped_fape": [],
}, },
"block_delete_msa": {
"msa_fraction_per_block": 0.3,
"randomize_num_blocks": False,
"num_blocks": 5,
},
"masked_msa": { "masked_msa": {
"profile_prob": 0.1, "profile_prob": 0.1,
"same_prob": 0.1, "same_prob": 0.1,
...@@ -365,6 +372,7 @@ config = mlc.ConfigDict( ...@@ -365,6 +372,7 @@ config = mlc.ConfigDict(
"predict": { "predict": {
"fixed_size": True, "fixed_size": True,
"subsample_templates": False, # We want top templates. "subsample_templates": False, # We want top templates.
"block_delete_msa": False,
"masked_msa_replace_fraction": 0.15, "masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 512, "max_msa_clusters": 512,
"max_extra_msa": 1024, "max_extra_msa": 1024,
...@@ -378,6 +386,7 @@ config = mlc.ConfigDict( ...@@ -378,6 +386,7 @@ config = mlc.ConfigDict(
"eval": { "eval": {
"fixed_size": True, "fixed_size": True,
"subsample_templates": False, # We want top templates. "subsample_templates": False, # We want top templates.
"block_delete_msa": False,
"masked_msa_replace_fraction": 0.15, "masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128, "max_msa_clusters": 128,
"max_extra_msa": 1024, "max_extra_msa": 1024,
...@@ -391,6 +400,7 @@ config = mlc.ConfigDict( ...@@ -391,6 +400,7 @@ config = mlc.ConfigDict(
"train": { "train": {
"fixed_size": True, "fixed_size": True,
"subsample_templates": True, "subsample_templates": True,
"block_delete_msa": True,
"masked_msa_replace_fraction": 0.15, "masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128, "max_msa_clusters": 128,
"max_extra_msa": 1024, "max_extra_msa": 1024,
......
...@@ -253,28 +253,33 @@ def block_delete_msa(protein, config): ...@@ -253,28 +253,33 @@ def block_delete_msa(protein, config):
* config.msa_fraction_per_block * config.msa_fraction_per_block
).to(torch.int32) ).to(torch.int32)
if int(block_num_seq) == 0:
return protein
if config.randomize_num_blocks: if config.randomize_num_blocks:
nb = torch.distributions.uniform.Uniform( nb = int(torch.randint(
0, config.num_blocks + 1 low=0,
).sample() high=config.num_blocks + 1,
size=(1,),
device=protein["msa"].device,
)[0])
else: else:
nb = config.num_blocks nb = config.num_blocks
del_block_starts = torch.distributions.Uniform(0, num_seq).sample(nb) del_block_starts = torch.randint(low=1, high=num_seq, size=(nb,), device=protein["msa"].device)
del_blocks = del_block_starts[:, None] + torch.range(block_num_seq) del_blocks = del_block_starts[:, None] + torch.arange(start=0, end=block_num_seq)
del_blocks = torch.clip(del_blocks, 0, num_seq - 1) del_blocks = torch.clip(del_blocks, 1, num_seq - 1)
del_indices = torch.unique(torch.sort(torch.reshape(del_blocks, [-1])))[0] del_indices = torch.unique(torch.reshape(del_blocks, [-1]))
# Make sure we keep the original sequence # Make sure we keep the original sequence
combined = torch.cat((torch.range(1, num_seq)[None], del_indices[None])) combined = torch.cat((torch.arange(start=0, end=num_seq), del_indices)).long()
uniques, counts = combined.unique(return_counts=True) uniques, counts = combined.unique(return_counts=True)
difference = uniques[counts == 1] keep_indices = uniques[counts == 1]
intersection = uniques[counts > 1]
keep_indices = torch.squeeze(difference, 0)
assert int(keep_indices[0]) == 0
for k in MSA_FEATURE_NAMES: for k in MSA_FEATURE_NAMES:
if k in protein: if k in protein:
protein[k] = torch.gather(protein[k], keep_indices) protein[k] = torch.index_select(protein[k], 0, keep_indices)
return protein return protein
......
...@@ -71,6 +71,9 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): ...@@ -71,6 +71,9 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
"""Input pipeline data transformers that can be ensembled and averaged.""" """Input pipeline data transformers that can be ensembled and averaged."""
transforms = [] transforms = []
if mode_cfg.block_delete_msa:
transforms.append(data_transforms.block_delete_msa(common_cfg.block_delete_msa))
if "max_distillation_msa_clusters" in mode_cfg: if "max_distillation_msa_clusters" in mode_cfg:
transforms.append( transforms.append(
data_transforms.sample_msa_distillation( data_transforms.sample_msa_distillation(
......
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads OpenFold parameters.
#
# Usage: bash download_openfold_params_huggingface.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aws &> /dev/null ; then
echo "Error: aws could not be found. Please install aws."
exit 1
fi
DOWNLOAD_DIR="${1}/openfold_soloseq_params"
mkdir -p "${DOWNLOAD_DIR}"
aws s3 cp --no-sign-request --region us-east-1 s3://openfold/openfold_soloseq_params/ "${DOWNLOAD_DIR}" --recursive
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment