Unverified Commit 49ab0539 authored by Jennifer Wei's avatar Jennifer Wei Committed by GitHub
Browse files

Merge pull request #407 from jnwei/pl_upgrades

Pytorch lightning upgrades
parents df4dfacb f0fc7d91
...@@ -5,7 +5,7 @@ jobs: ...@@ -5,7 +5,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-python@v4 - uses: actions/setup-python@v5
- run: pip install --upgrade pip - run: pip install --upgrade pip
- run: pip install flake8 - run: pip install flake8
- run: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - run: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
...@@ -9,4 +9,4 @@ dist ...@@ -9,4 +9,4 @@ dist
data data
openfold/resources/ openfold/resources/
tests/test_data/ tests/test_data/
cutlass cutlass/
...@@ -13,7 +13,7 @@ RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/ ...@@ -13,7 +13,7 @@ RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/
RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-11-3 libcusparse-dev-11-3 libcublas-dev-11-3 libcusolver-dev-11-3 git RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-11-3 libcusparse-dev-11-3 libcublas-dev-11-3 libcusolver-dev-11-3 git
RUN wget -P /tmp \ RUN wget -P /tmp \
"https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh" \ "https://github.com/conda-forge/miniforge/releases/download/23.3.1-1/Miniforge3-Linux-x86_64.sh" \
&& bash /tmp/Miniforge3-Linux-x86_64.sh -b -p /opt/conda \ && bash /tmp/Miniforge3-Linux-x86_64.sh -b -p /opt/conda \
&& rm /tmp/Miniforge3-Linux-x86_64.sh && rm /tmp/Miniforge3-Linux-x86_64.sh
ENV PATH /opt/conda/bin:$PATH ENV PATH /opt/conda/bin:$PATH
......
...@@ -351,7 +351,7 @@ python3 run_pretrained_openfold.py \ ...@@ -351,7 +351,7 @@ python3 run_pretrained_openfold.py \
--output_dir ./ \ --output_dir ./ \
--model_device "cuda:0" \ --model_device "cuda:0" \
--config_preset "seq_model_esm1b_ptm" \ --config_preset "seq_model_esm1b_ptm" \
--openfold_checkpoint_path openfold/resources/openfold_params/seq_model_esm1b_ptm.pt \ --openfold_checkpoint_path openfold/resources/openfold_soloseq_params/seq_model_esm1b_ptm.pt \
--uniref90_database_path data/uniref90/uniref90.fasta \ --uniref90_database_path data/uniref90/uniref90.fasta \
--pdb70_database_path data/pdb70/pdb70 \ --pdb70_database_path data/pdb70/pdb70 \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \ --jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
...@@ -595,4 +595,4 @@ If you use OpenProteinSet, please also cite: ...@@ -595,4 +595,4 @@ If you use OpenProteinSet, please also cite:
primaryClass={q-bio.BM} primaryClass={q-bio.BM}
} }
``` ```
Any work that cites OpenFold should also cite AlphaFold. Any work that cites OpenFold should also cite [AlphaFold](https://www.nature.com/articles/s41586-021-03819-2) and [AlphaFold-Multimer](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1) if applicable.
...@@ -3,8 +3,7 @@ ...@@ -3,8 +3,7 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "view-in-github", "id": "view-in-github"
"colab_type": "text"
}, },
"source": [ "source": [
"<a href=\"https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" "<a href=\"https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
...@@ -52,25 +51,44 @@ ...@@ -52,25 +51,44 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"cellView": "form",
"id": "rowN0bVYLe9n" "id": "rowN0bVYLe9n"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"#@markdown ### Enter the amino acid sequence to fold ⬇️\n", "#@markdown ### Enter the amino acid sequence to fold ⬇️\n",
"sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n", "#@markdown For multiple sequences, separate sequences with a colon `:`\n",
"input_sequence = 'MKLKQVADKLEEVASKLYHNANELARVAKLLGER:MKLKQVADKLEEVASKLYHNANELARVAKLLGER: MKLKQVADKLEEVASKLYHNANELARVAKLLGER:MKLKQVADKLEEVASKLYHNANELARVAKLLGER' #@param {type:\"string\"}\n",
"\n", "\n",
"#@markdown ### Configure the model ⬇️\n", "#@markdown ### Configure the model ⬇️\n",
"\n", "\n",
"weight_set = 'OpenFold' #@param [\"OpenFold\", \"AlphaFold\"]\n", "weight_set = 'AlphaFold' #@param [\"OpenFold\", \"AlphaFold\"]\n",
"model_mode = 'multimer' #@param [\"monomer\", \"multimer\"]\n",
"relax_prediction = True #@param {type:\"boolean\"}\n", "relax_prediction = True #@param {type:\"boolean\"}\n",
"\n", "\n",
"\n",
"# Remove all whitespaces, tabs and end lines; upper-case\n", "# Remove all whitespaces, tabs and end lines; upper-case\n",
"sequence = sequence.translate(str.maketrans('', '', ' \\n\\t')).upper()\n", "input_sequence = input_sequence.translate(str.maketrans('', '', ' \\n\\t')).upper()\n",
"aatypes = set('ACDEFGHIKLMNPQRSTVWY') # 20 standard aatypes\n", "aatypes = set('ACDEFGHIKLMNPQRSTVWY') # 20 standard aatypes\n",
"if not set(sequence).issubset(aatypes):\n", "allowed_chars = aatypes.union({':'})\n",
" raise Exception(f'Input sequence contains non-amino acid letters: {set(sequence) - aatypes}. OpenFold only supports 20 standard amino acids as inputs.')\n", "if not set(input_sequence).issubset(allowed_chars):\n",
"\n", " raise Exception(f'Input sequence contains non-amino acid letters: {set(input_sequence) - allowed_chars}. OpenFold only supports 20 standard amino acids as inputs.')\n",
"\n",
"if ':' in input_sequence and weight_set != 'AlphaFold':\n",
" raise ValueError('Input sequence is a multimer, must select Alphafold weight set')\n",
"\n",
"import enum\n",
"@enum.unique\n",
"class ModelType(enum.Enum):\n",
" MONOMER = 0\n",
" MULTIMER = 1\n",
"\n",
"model_type_dict = {\n",
" 'monomer': ModelType.MONOMER,\n",
" 'multimer': ModelType.MULTIMER,\n",
"}\n",
"\n",
"model_type = model_type_dict[model_mode]\n",
"print(f'Length of input sequence : {len(input_sequence.replace(\":\", \"\"))}')\n",
"#@markdown After making your selections, execute this cell by pressing the\n", "#@markdown After making your selections, execute this cell by pressing the\n",
"#@markdown *Play* button on the left." "#@markdown *Play* button on the left."
] ]
...@@ -79,17 +97,16 @@ ...@@ -79,17 +97,16 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"cellView": "form",
"id": "woIxeCPygt7K" "id": "woIxeCPygt7K"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"#@title Install third-party software\n", "#@title Install third-party software\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n", "#@markdown Please execute this cell by pressing the *Play* button on\n",
"#@markdown the left.\n", "#@markdown the left.\n",
"\n", "\n",
"\n", "\n",
"#@markdown **Note**: This installs the software on the Colab \n", "#@markdown **Note**: This installs the software on the Colab\n",
"#@markdown notebook in the cloud and not on your computer.\n", "#@markdown notebook in the cloud and not on your computer.\n",
"\n", "\n",
"import os, time\n", "import os, time\n",
...@@ -103,10 +120,8 @@ ...@@ -103,10 +120,8 @@
"os.system(\"wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh\")\n", "os.system(\"wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh\")\n",
"os.system(\"bash Mambaforge-Linux-x86_64.sh -bfp /usr/local\")\n", "os.system(\"bash Mambaforge-Linux-x86_64.sh -bfp /usr/local\")\n",
"os.system(\"mamba config --set auto_update_conda false\")\n", "os.system(\"mamba config --set auto_update_conda false\")\n",
"os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=7.7.0 python={python_version} pdbfixer\")\n", "os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=7.7.0 python={python_version} pdbfixer biopython=1.79\")\n",
"\n", "os.system(\"pip install -q torch ml_collections py3Dmol modelcif\")\n",
"\n",
"os.system(\"pip install -q \\\"torch<2\\\" biopython ml_collections py3Dmol modelcif\")\n",
"\n", "\n",
"try:\n", "try:\n",
" with io.capture_output() as captured:\n", " with io.capture_output() as captured:\n",
...@@ -119,12 +134,12 @@ ...@@ -119,12 +134,12 @@
" %shell wget -q -P /content \\\n", " %shell wget -q -P /content \\\n",
" https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt\n", " https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt\n",
"\n", "\n",
" %shell mkdir -p /content/openfold/openfold/resourcees\n", " %shell mkdir -p /content/openfold/openfold/resources\n",
" \n", "\n",
" commit = \"099769d2ecfd01a8baa8d950030df454a042c910\"\n", " commit = \"e2e19f16676b1a409f9ba3a6f69b11ee7f5887c2\"\n",
" os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n", " os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n",
" \n", "\n",
" %shell cp -f /content/stereo_chemical_props.txt /usr/local/lib/python3.10/site-packages/openfold/resources/\n", " os.system(f\"cp -f -p /content/stereo_chemical_props.txt /usr/local/lib/python{python_version}/site-packages/openfold/resources/\")\n",
"\n", "\n",
"except subprocess.CalledProcessError as captured:\n", "except subprocess.CalledProcessError as captured:\n",
" print(captured)" " print(captured)"
...@@ -134,18 +149,17 @@ ...@@ -134,18 +149,17 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"cellView": "form",
"id": "VzJ5iMjTtoZw" "id": "VzJ5iMjTtoZw"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"#@title Download model weights \n", "#@title Download model weights\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n", "#@markdown Please execute this cell by pressing the *Play* button on\n",
"#@markdown the left.\n", "#@markdown the left.\n",
"\n", "\n",
"# Define constants\n", "# Define constants\n",
"GIT_REPO='https://github.com/aqlaboratory/openfold'\n", "GIT_REPO='https://github.com/aqlaboratory/openfold'\n",
"ALPHAFOLD_PARAM_SOURCE_URL = 'https://storage.googleapis.com/alphafold/alphafold_params_2022-01-19.tar'\n", "ALPHAFOLD_PARAM_SOURCE_URL = 'https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar'\n",
"OPENFOLD_PARAMS_DIR = './openfold/openfold/resources/openfold_params'\n", "OPENFOLD_PARAMS_DIR = './openfold/openfold/resources/openfold_params'\n",
"ALPHAFOLD_PARAMS_DIR = './openfold/openfold/resources/params'\n", "ALPHAFOLD_PARAMS_DIR = './openfold/openfold/resources/params'\n",
"ALPHAFOLD_PARAMS_PATH = os.path.join(\n", "ALPHAFOLD_PARAMS_PATH = os.path.join(\n",
...@@ -184,17 +198,17 @@ ...@@ -184,17 +198,17 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"cellView": "form",
"id": "_FpxxMo-mvcP" "id": "_FpxxMo-mvcP"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"#@title Import Python packages\n", "#@title Import Python packages\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n", "#@markdown Please execute this cell by pressing the *Play* button on\n",
"#@markdown the left.\n", "#@markdown the left.\n",
"\n", "\n",
"import unittest.mock\n", "import unittest.mock\n",
"import sys\n", "import sys\n",
"from typing import Dict, Sequence\n",
"\n", "\n",
"sys.path.insert(0, f'/usr/local/lib/python{python_version}/dist-packages/')\n", "sys.path.insert(0, f'/usr/local/lib/python{python_version}/dist-packages/')\n",
"sys.path.insert(0, f'/usr/local/lib/python{python_version}/site-packages/')\n", "sys.path.insert(0, f'/usr/local/lib/python{python_version}/site-packages/')\n",
...@@ -234,22 +248,12 @@ ...@@ -234,22 +248,12 @@
" return \"UTF-8\"\n", " return \"UTF-8\"\n",
"locale.getpreferredencoding = getpreferredencoding\n", "locale.getpreferredencoding = getpreferredencoding\n",
"\n", "\n",
"# A filthy hack to avoid slow Linear layer initialization\n",
"import openfold.model.primitives\n",
"\n",
"def __default_linear_init__(self, *args, **kwargs):\n",
" return torch.nn.Linear.__init__(\n",
" self, \n",
" *args[:2], \n",
" **{k:v for k,v in kwargs.items() if k == \"bias\"}\n",
" )\n",
"\n",
"openfold.model.primitives.Linear.__init__ = __default_linear_init__\n",
"\n",
"from openfold import config\n", "from openfold import config\n",
"from openfold.data import feature_pipeline\n", "from openfold.data import feature_pipeline\n",
"from openfold.data import parsers\n", "from openfold.data import parsers\n",
"from openfold.data import data_pipeline\n", "from openfold.data import data_pipeline\n",
"from openfold.data import msa_pairing\n",
"from openfold.data import feature_processing_multimer\n",
"from openfold.data.tools import jackhmmer\n", "from openfold.data.tools import jackhmmer\n",
"from openfold.model import model\n", "from openfold.model import model\n",
"from openfold.np import protein\n", "from openfold.np import protein\n",
...@@ -276,22 +280,16 @@ ...@@ -276,22 +280,16 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "2tTeTTsLKPjB"
},
"outputs": [],
"source": [ "source": [
"#@title Search against genetic databases\n", "#@title Search against genetic databases\n",
"\n", "\n",
"#@markdown Once this cell has been executed, you will see\n", "#@markdown Once this cell has been executed, you will see\n",
"#@markdown statistics about the multiple sequence alignment \n", "#@markdown statistics about the multiple sequence alignment\n",
"#@markdown (MSA) that will be used by OpenFold. In particular, \n", "#@markdown (MSA) that will be used by OpenFold. In particular,\n",
"#@markdown you’ll see how well each residue is covered by similar \n", "#@markdown you’ll see how well each residue is covered by similar\n",
"#@markdown sequences in the MSA.\n", "#@markdown sequences in the MSA.\n",
"\n", "\n",
"# --- Find the closest source ---\n", "# --- Find the closest source --\n",
"test_url_pattern = 'https://storage.googleapis.com/alphafold-colab{:s}/latest/uniref90_2021_03.fasta.1'\n", "test_url_pattern = 'https://storage.googleapis.com/alphafold-colab{:s}/latest/uniref90_2021_03.fasta.1'\n",
"ex = futures.ThreadPoolExecutor(3)\n", "ex = futures.ThreadPoolExecutor(3)\n",
"def fetch(source):\n", "def fetch(source):\n",
...@@ -304,114 +302,156 @@ ...@@ -304,114 +302,156 @@
" ex.shutdown()\n", " ex.shutdown()\n",
" break\n", " break\n",
"\n", "\n",
"# --- Search against genetic databases ---\n",
"with open('target.fasta', 'wt') as f:\n",
" f.write(f'>query\\n{sequence}')\n",
"\n",
"# Run the search against chunks of genetic databases (since the genetic\n", "# Run the search against chunks of genetic databases (since the genetic\n",
"# databases don't fit in Colab ramdisk).\n", "# databases don't fit in Colab ramdisk).\n",
"\n", "\n",
"jackhmmer_binary_path = '/usr/bin/jackhmmer'\n", "jackhmmer_binary_path = '/usr/bin/jackhmmer'\n",
"dbs = []\n",
"\n", "\n",
"num_jackhmmer_chunks = {'uniref90': 59, 'smallbfd': 17, 'mgnify': 71}\n", "# --- Parse multiple sequences, if there are any ---\n",
"total_jackhmmer_chunks = sum(num_jackhmmer_chunks.values())\n", "def split_multiple_sequences(sequence):\n",
" seqs = sequence.split(':')\n",
" sorted_seqs = sorted(seqs, key=lambda s: len(s))\n",
"\n",
" # TODO: Handle the homomer case when writing fasta sequences\n",
" fasta_path_tuples = []\n",
" for idx, seq in enumerate(set(sorted_seqs)):\n",
" fasta_path = f'target_{idx+1}.fasta'\n",
" with open(fasta_path, 'wt') as f:\n",
" f.write(f'>query\\n{seq}\\n')\n",
" fasta_path_tuples.append((seq, fasta_path))\n",
" fasta_path_by_seq = dict(fasta_path_tuples)\n",
"\n",
" return sorted_seqs, fasta_path_by_seq\n",
"\n",
"sequences, fasta_path_by_sequence = split_multiple_sequences(input_sequence)\n",
"db_results_by_sequence = {seq: {} for seq in fasta_path_by_sequence.keys()}\n",
"\n",
"DB_ROOT_PATH = f'https://storage.googleapis.com/alphafold-colab{source}/latest/'\n",
"db_configs = {}\n",
"db_configs['smallbfd'] = {\n",
" 'database_path': f'{DB_ROOT_PATH}uniref90_2021_03.fasta',\n",
" 'z_value': 65984053,\n",
" 'num_jackhmmer_chunks': 17,\n",
"}\n",
"db_configs['mgnify'] = {\n",
" 'database_path': f'{DB_ROOT_PATH}mgy_clusters_2022_05.fasta',\n",
" 'z_value': 304820129,\n",
" 'num_jackhmmer_chunks': 120,\n",
"}\n",
"db_configs['uniref90'] = {\n",
" 'database_path': f'{DB_ROOT_PATH}uniref90_2022_01.fasta',\n",
" 'z_value': 144113457,\n",
" 'num_jackhmmer_chunks': 62,\n",
"}\n",
"\n",
"# Search UniProt and construct the all_seq features only for heteromers, not homomers.\n",
"if model_type == ModelType.MULTIMER and len(set(sequences)) > 1:\n",
" db_configs['uniprot'] = {\n",
" 'database_path': f'{DB_ROOT_PATH}uniprot_2021_04.fasta',\n",
" 'z_value': 225013025 + 565928,\n",
" 'num_jackhmmer_chunks': 101,\n",
" }\n",
"\n",
"total_jackhmmer_chunks = sum([d['num_jackhmmer_chunks'] for d in db_configs.values()])\n",
"with tqdm.notebook.tqdm(total=total_jackhmmer_chunks, bar_format=TQDM_BAR_FORMAT) as pbar:\n", "with tqdm.notebook.tqdm(total=total_jackhmmer_chunks, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" def jackhmmer_chunk_callback(i):\n", " def jackhmmer_chunk_callback(i):\n",
" pbar.update(n=1)\n", " pbar.update(n=1)\n",
"\n", "\n",
" pbar.set_description('Searching uniref90')\n", " for db_name, db_config in db_configs.items():\n",
" jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(\n", " pbar.set_description(f'Searching {db_name}')\n",
" binary_path=jackhmmer_binary_path,\n", " jackhmmer_runner = jackhmmer.Jackhmmer(\n",
" database_path=f'https://storage.googleapis.com/alphafold-colab{source}/latest/uniref90_2021_03.fasta',\n", " binary_path=jackhmmer_binary_path,\n",
" get_tblout=True,\n", " database_path=db_config['database_path'],\n",
" num_streamed_chunks=num_jackhmmer_chunks['uniref90'],\n", " get_tblout=True,\n",
" streaming_callback=jackhmmer_chunk_callback,\n", " num_streamed_chunks=db_config['num_jackhmmer_chunks'],\n",
" z_value=135301051)\n", " streaming_callback=jackhmmer_chunk_callback,\n",
" dbs.append(('uniref90', jackhmmer_uniref90_runner.query('target.fasta')))\n", " z_value=db_config['z_value'])\n",
"\n", "\n",
" pbar.set_description('Searching smallbfd')\n", " db_results = jackhmmer_runner.query_multiple(fasta_path_by_sequence.values())\n",
" jackhmmer_smallbfd_runner = jackhmmer.Jackhmmer(\n", " for seq, result in zip(fasta_path_by_sequence.keys(), db_results):\n",
" binary_path=jackhmmer_binary_path,\n", " db_results_by_sequence[seq][db_name] = result\n",
" database_path=f'https://storage.googleapis.com/alphafold-colab{source}/latest/bfd-first_non_consensus_sequences.fasta',\n",
" get_tblout=True,\n",
" num_streamed_chunks=num_jackhmmer_chunks['smallbfd'],\n",
" streaming_callback=jackhmmer_chunk_callback,\n",
" z_value=65984053)\n",
" dbs.append(('smallbfd', jackhmmer_smallbfd_runner.query('target.fasta')))\n",
"\n",
" pbar.set_description('Searching mgnify')\n",
" jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(\n",
" binary_path=jackhmmer_binary_path,\n",
" database_path=f'https://storage.googleapis.com/alphafold-colab{source}/latest/mgy_clusters_2019_05.fasta',\n",
" get_tblout=True,\n",
" num_streamed_chunks=num_jackhmmer_chunks['mgnify'],\n",
" streaming_callback=jackhmmer_chunk_callback,\n",
" z_value=304820129)\n",
" dbs.append(('mgnify', jackhmmer_mgnify_runner.query('target.fasta')))\n",
"\n", "\n",
"\n", "\n",
"# --- Extract the MSAs and visualize ---\n", "# --- Extract the MSAs and visualize ---\n",
"# Extract the MSAs from the Stockholm files.\n", "# Extract the MSAs from the Stockholm files.\n",
"# NB: deduplication happens later in data_pipeline.make_msa_features.\n", "# NB: deduplication happens later in data_pipeline.make_msa_features.\n",
"\n", "\n",
"mgnify_max_hits = 501\n", "MAX_HITS_BY_DB = {\n",
"\n", " 'uniref90': 10000,\n",
"msas = []\n", " 'smallbfd': 5000,\n",
"deletion_matrices = []\n", " 'mgnify': 501,\n",
"full_msa = []\n", " 'uniprot': 50000,\n",
"for db_name, db_results in dbs:\n", "}\n",
" unsorted_results = []\n", "\n",
" for i, result in enumerate(db_results):\n", "msas_by_seq_by_db = {seq: {} for seq in sequences}\n",
" msa, deletion_matrix, target_names = parsers.parse_stockholm(result['sto'])\n", "full_msa_by_seq = {seq: [] for seq in sequences}\n",
" e_values_dict = parsers.parse_e_values_from_tblout(result['tbl'])\n", "\n",
" e_values = [e_values_dict[t.split('/')[0]] for t in target_names]\n", "for seq, sequence_result in db_results_by_sequence.items():\n",
" zipped_results = zip(msa, deletion_matrix, target_names, e_values)\n", " print(f'parsing_results_for_sequence {seq}')\n",
" if i != 0:\n", " for db_name, db_results in sequence_result.items():\n",
" # Only take query from the first chunk\n", " unsorted_results = []\n",
" zipped_results = [x for x in zipped_results if x[2] != 'query']\n", " for i, result in enumerate(db_results):\n",
" unsorted_results.extend(zipped_results)\n", " msa_obj = parsers.parse_stockholm(result['sto'])\n",
" sorted_by_evalue = sorted(unsorted_results, key=lambda x: x[3])\n", " e_values_dict = parsers.parse_e_values_from_tblout(result['tbl'])\n",
" db_msas, db_deletion_matrices, _, _ = zip(*sorted_by_evalue)\n", " target_names = msa_obj.descriptions\n",
" if db_msas:\n", " e_values = [e_values_dict[t.split('/')[0]] for t in target_names]\n",
" if db_name == 'mgnify':\n", " zipped_results = zip(msa_obj.sequences, msa_obj.deletion_matrix, target_names, e_values)\n",
" db_msas = db_msas[:mgnify_max_hits]\n", " if i != 0:\n",
" db_deletion_matrices = db_deletion_matrices[:mgnify_max_hits]\n", " # Only take query from the first chunk\n",
" full_msa.extend(db_msas)\n", " zipped_results = [x for x in zipped_results if x[2] != 'query']\n",
" msas.append(db_msas)\n", " unsorted_results.extend(zipped_results)\n",
" deletion_matrices.append(db_deletion_matrices)\n", " sorted_by_evalue = sorted(unsorted_results, key=lambda x: x[3])\n",
" msa_size = len(set(db_msas))\n", " msas, del_matrix, targets, _ = zip(*sorted_by_evalue)\n",
" print(f'{msa_size} Sequences Found in {db_name}')\n", " db_msas = parsers.Msa(msas, del_matrix, targets)\n",
"\n", " if db_msas:\n",
"deduped_full_msa = list(dict.fromkeys(full_msa))\n", " if db_name in MAX_HITS_BY_DB:\n",
"total_msa_size = len(deduped_full_msa)\n", " db_msas.truncate(MAX_HITS_BY_DB[db_name])\n",
"print(f'\\n{total_msa_size} Sequences Found in Total\\n')\n", " msas_by_seq_by_db[seq][db_name] = db_msas\n",
"\n", " full_msa_by_seq[seq].extend(db_msas.sequences)\n",
"aa_map = {restype: i for i, restype in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ-')}\n", " msa_size = len(set(db_msas.sequences))\n",
"msa_arr = np.array([[aa_map[aa] for aa in seq] for seq in deduped_full_msa])\n", " print(f'{msa_size} Sequences Found in {db_name}')\n",
"num_alignments, num_res = msa_arr.shape\n", "\n",
"\n", "\n",
"fig = plt.figure(figsize=(12, 3))\n", "fig = plt.figure(figsize=(12, 3))\n",
"max_num_alignments = 0\n",
"\n",
"for seq_idx, seq in enumerate(set(sequences)):\n",
" full_msas = full_msa_by_seq[seq]\n",
" deduped_full_msa = list(dict.fromkeys(full_msas))\n",
" total_msa_size = len(deduped_full_msa)\n",
" print(f'\\n{total_msa_size} Sequences Found in Total\\n')\n",
"\n",
" aa_map = {restype: i for i, restype in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ-')}\n",
" msa_arr = np.array([[aa_map[aa] for aa in seq] for seq in deduped_full_msa])\n",
" num_alignments, num_res = msa_arr.shape\n",
" plt.plot(np.sum(msa_arr != aa_map['-'], axis=0), label=f'Chain {seq_idx}')\n",
" max_num_alignments = max(num_alignments, max_num_alignments)\n",
"\n",
"\n",
"plt.title('Per-Residue Count of Non-Gap Amino Acids in the MSA')\n", "plt.title('Per-Residue Count of Non-Gap Amino Acids in the MSA')\n",
"plt.plot(np.sum(msa_arr != aa_map['-'], axis=0), color='black')\n",
"plt.ylabel('Non-Gap Count')\n", "plt.ylabel('Non-Gap Count')\n",
"plt.yticks(range(0, num_alignments + 1, max(1, int(num_alignments / 3))))\n", "plt.yticks(range(0, max_num_alignments + 1, max(1, int(max_num_alignments / 3))))\n",
"plt.legend()\n",
"plt.show()" "plt.show()"
] ],
"metadata": {
"id": "o7BqQN_gfYtq"
},
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"cellView": "form",
"id": "XUo6foMQxwS2" "id": "XUo6foMQxwS2"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"#@title Run OpenFold and download prediction\n", "#@title Run OpenFold and download prediction\n",
"\n", "\n",
"#@markdown Once this cell has been executed, a zip-archive with \n", "#@markdown Once this cell has been executed, a zip-archive with\n",
"#@markdown the obtained prediction will be automatically downloaded \n", "#@markdown the obtained prediction will be automatically downloaded\n",
"#@markdown to your computer.\n", "#@markdown to your computer.\n",
"\n", "\n",
"# Color bands for visualizing plddt\n", "# Color bands for visualizing plddt\n",
...@@ -423,13 +463,22 @@ ...@@ -423,13 +463,22 @@
"]\n", "]\n",
"\n", "\n",
"# --- Run the model ---\n", "# --- Run the model ---\n",
"model_names = [ \n", "if model_type == ModelType.MONOMER:\n",
" 'finetuning_3.pt', \n", " model_names = [\n",
" 'finetuning_4.pt', \n", " 'finetuning_3.pt',\n",
" 'finetuning_5.pt', \n", " 'finetuning_4.pt',\n",
" 'finetuning_ptm_2.pt',\n", " 'finetuning_5.pt',\n",
" 'finetuning_no_templ_ptm_1.pt'\n", " 'finetuning_ptm_2.pt',\n",
"]\n", " 'finetuning_no_templ_ptm_1.pt'\n",
" ]\n",
"elif model_type == ModelType.MULTIMER:\n",
" model_names = [\n",
" 'model_1_multimer_v3',\n",
" 'model_2_multimer_v3',\n",
" 'model_3_multimer_v3',\n",
" 'model_4_multimer_v3',\n",
" 'model_5_multimer_v3',\n",
" ]\n",
"\n", "\n",
"def _placeholder_template_feats(num_templates_, num_res_):\n", "def _placeholder_template_feats(num_templates_, num_res_):\n",
" return {\n", " return {\n",
...@@ -440,26 +489,72 @@ ...@@ -440,26 +489,72 @@
" 'template_sum_probs': np.zeros((num_templates_, 1), dtype=np.float32),\n", " 'template_sum_probs': np.zeros((num_templates_, 1), dtype=np.float32),\n",
" }\n", " }\n",
"\n", "\n",
"\n",
"def make_features(\n",
" sequences: Sequence[str],\n",
" msas_by_seq_by_db: Dict[str, Dict[str, parsers.Msa]],\n",
" model_type: ModelType):\n",
" num_templates = 1 # Placeholder for generating fake template features\n",
" feature_dict = {}\n",
"\n",
" for idx, seq in enumerate(sequences, start=1):\n",
" _chain_id = f'chain_{idx}'\n",
" num_res = len(seq)\n",
"\n",
" feats = data_pipeline.make_sequence_features(seq, _chain_id, num_res)\n",
" msas_without_uniprot = [msas_by_seq_by_db[seq][db] for db in db_configs.keys() if db != 'uniprot']\n",
" msa_feats = data_pipeline.make_msa_features(msas_without_uniprot)\n",
" feats.update(msa_feats)\n",
" feats.update(_placeholder_template_feats(num_templates, num_res))\n",
"\n",
" if model_type == ModelType.MONOMER:\n",
" feature_dict[seq] = feats\n",
" if model_type == ModelType.MULTIMER:\n",
" # Perform extra pair processing steps for heteromers\n",
" if len(set(sequences)) > 1:\n",
" uniprot_msa = msas_by_seq_by_db[seq]['uniprot']\n",
" uniprot_msa_features = data_pipeline.make_msa_features([uniprot_msa])\n",
" valid_feat_names = msa_pairing.MSA_FEATURES + (\n",
" 'msa_species_identifiers',\n",
" )\n",
" pair_feats = {\n",
" f'{k}_all_seq': v for k, v in uniprot_msa_features.items()\n",
" if k in valid_feat_names\n",
" }\n",
" feats.update(pair_feats)\n",
"\n",
" feats = data_pipeline.convert_monomer_features(feats, _chain_id)\n",
" feature_dict[_chain_id] = feats\n",
"\n",
" if model_type == ModelType.MONOMER:\n",
" np_example = feature_dict[sequences[0]]\n",
" elif model_type == ModelType.MULTIMER:\n",
" all_chain_feats = data_pipeline.add_assembly_features(feature_dict)\n",
" features = feature_processing_multimer.pair_and_merge(all_chain_features=all_chain_feats)\n",
" np_example = data_pipeline.pad_msa(features, 512)\n",
"\n",
" return np_example\n",
"\n",
"\n",
"output_dir = 'prediction'\n", "output_dir = 'prediction'\n",
"os.makedirs(output_dir, exist_ok=True)\n", "os.makedirs(output_dir, exist_ok=True)\n",
"\n", "\n",
"plddts = {}\n", "plddts = {}\n",
"pae_outputs = {}\n", "pae_outputs = {}\n",
"weighted_ptms = {}\n",
"unrelaxed_proteins = {}\n", "unrelaxed_proteins = {}\n",
"\n", "\n",
"with tqdm.notebook.tqdm(total=len(model_names) + 1, bar_format=TQDM_BAR_FORMAT) as pbar:\n", "with tqdm.notebook.tqdm(total=len(model_names), bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" for i, model_name in list(enumerate(model_names)):\n", " for i, model_name in enumerate(model_names, start = 1):\n",
" pbar.set_description(f'Running {model_name}')\n", " pbar.set_description(f'Running {model_name}')\n",
" num_templates = 1 # dummy number --- is ignored\n", "\n",
" num_res = len(sequence)\n", " feature_dict = make_features(sequences, msas_by_seq_by_db, model_type)\n",
" \n",
" feature_dict = {}\n",
" feature_dict.update(data_pipeline.make_sequence_features(sequence, 'test', num_res))\n",
" feature_dict.update(data_pipeline.make_msa_features(msas, deletion_matrices=deletion_matrices))\n",
" feature_dict.update(_placeholder_template_feats(num_templates, num_res))\n",
"\n", "\n",
" if(weight_set == \"AlphaFold\"):\n", " if(weight_set == \"AlphaFold\"):\n",
" config_preset = f\"model_{i}\"\n", " if model_type == ModelType.MONOMER:\n",
" config_preset = f\"model_{i}\"\n",
" elif model_type == ModelType.MULTIMER:\n",
" config_preset = f'model_{i}_multimer_v3'\n",
" else:\n", " else:\n",
" if(\"_no_templ_\" in model_name):\n", " if(\"_no_templ_\" in model_name):\n",
" config_preset = \"model_3\"\n", " config_preset = \"model_3\"\n",
...@@ -469,6 +564,11 @@ ...@@ -469,6 +564,11 @@
" config_preset += \"_ptm\"\n", " config_preset += \"_ptm\"\n",
"\n", "\n",
" cfg = config.model_config(config_preset)\n", " cfg = config.model_config(config_preset)\n",
"\n",
" # Force the model to only use 3 recycling updates\n",
" cfg.data.common.max_recycling_iters = 3\n",
" cfg.model.recycle_early_stop_tolerance = -1\n",
"\n",
" openfold_model = model.AlphaFold(cfg)\n", " openfold_model = model.AlphaFold(cfg)\n",
" openfold_model = openfold_model.eval()\n", " openfold_model = openfold_model.eval()\n",
" if(weight_set == \"AlphaFold\"):\n", " if(weight_set == \"AlphaFold\"):\n",
...@@ -490,7 +590,9 @@ ...@@ -490,7 +590,9 @@
"\n", "\n",
" pipeline = feature_pipeline.FeaturePipeline(cfg.data)\n", " pipeline = feature_pipeline.FeaturePipeline(cfg.data)\n",
" processed_feature_dict = pipeline.process_features(\n", " processed_feature_dict = pipeline.process_features(\n",
" feature_dict, mode='predict'\n", " feature_dict,\n",
" mode='predict',\n",
" is_multimer = (model_type == ModelType.MULTIMER),\n",
" )\n", " )\n",
"\n", "\n",
" processed_feature_dict = tensor_tree_map(\n", " processed_feature_dict = tensor_tree_map(\n",
...@@ -510,21 +612,32 @@ ...@@ -510,21 +612,32 @@
"\n", "\n",
" mean_plddt = prediction_result['plddt'].mean()\n", " mean_plddt = prediction_result['plddt'].mean()\n",
"\n", "\n",
" if 'predicted_aligned_error' in prediction_result:\n", " if model_type == ModelType.MONOMER:\n",
" pae_outputs[model_name] = (\n", " if 'predicted_aligned_error' in prediction_result:\n",
" prediction_result['predicted_aligned_error'],\n", " pae_outputs[model_name] = (\n",
" prediction_result['max_predicted_aligned_error']\n", " prediction_result['predicted_aligned_error'],\n",
" )\n", " prediction_result['max_predicted_aligned_error']\n",
" else:\n", " )\n",
" # Get the pLDDT confidence metrics. Do not put pTM models here as they\n", " else:\n",
" # should never get selected.\n", " # Get the pLDDT confidence metrics. Do not put pTM models here as they\n",
" # should never get selected.\n",
" plddts[model_name] = prediction_result['plddt']\n",
" elif model_type == ModelType.MULTIMER:\n",
" # Multimer models are sorted by pTM+ipTM.\n",
" plddts[model_name] = prediction_result['plddt']\n", " plddts[model_name] = prediction_result['plddt']\n",
" pae_outputs[model_name] = (prediction_result['predicted_aligned_error'],\n",
" prediction_result['max_predicted_aligned_error'])\n",
"\n",
" weighted_ptms[model_name] = prediction_result['weighted_ptm_score']\n",
"\n", "\n",
" # Set the b-factors to the per-residue plddt.\n", " # Set the b-factors to the per-residue plddt.\n",
" final_atom_mask = prediction_result['final_atom_mask']\n", " final_atom_mask = prediction_result['final_atom_mask']\n",
" b_factors = prediction_result['plddt'][:, None] * final_atom_mask\n", " b_factors = prediction_result['plddt'][:, None] * final_atom_mask\n",
" unrelaxed_protein = protein.from_prediction(\n", " unrelaxed_protein = protein.from_prediction(\n",
" processed_feature_dict, prediction_result, b_factors=b_factors\n", " processed_feature_dict,\n",
" prediction_result,\n",
" remove_leading_feature_dimension=False,\n",
" b_factors=b_factors,\n",
" )\n", " )\n",
" unrelaxed_proteins[model_name] = unrelaxed_protein\n", " unrelaxed_proteins[model_name] = unrelaxed_protein\n",
"\n", "\n",
...@@ -535,7 +648,10 @@ ...@@ -535,7 +648,10 @@
" pbar.update(n=1)\n", " pbar.update(n=1)\n",
"\n", "\n",
" # Find the best model according to the mean pLDDT.\n", " # Find the best model according to the mean pLDDT.\n",
" best_model_name = max(plddts.keys(), key=lambda x: plddts[x].mean())\n", " if model_type == ModelType.MONOMER:\n",
" best_model_name = max(plddts.keys(), key=lambda x: plddts[x].mean())\n",
" elif model_type == ModelType.MULTIMER:\n",
" best_model_name = max(weighted_ptms.keys(), key=lambda x: weighted_ptms[x])\n",
" best_pdb = protein.to_pdb(unrelaxed_proteins[best_model_name])\n", " best_pdb = protein.to_pdb(unrelaxed_proteins[best_model_name])\n",
"\n", "\n",
" # --- AMBER relax the best model ---\n", " # --- AMBER relax the best model ---\n",
...@@ -547,7 +663,7 @@ ...@@ -547,7 +663,7 @@
" stiffness=10.0,\n", " stiffness=10.0,\n",
" exclude_residues=[],\n", " exclude_residues=[],\n",
" max_outer_iterations=20,\n", " max_outer_iterations=20,\n",
" use_gpu=False,\n", " use_gpu=True,\n",
" )\n", " )\n",
" relaxed_pdb, _, _ = amber_relaxer.process(\n", " relaxed_pdb, _, _ = amber_relaxer.process(\n",
" prot=unrelaxed_proteins[best_model_name]\n", " prot=unrelaxed_proteins[best_model_name]\n",
...@@ -598,6 +714,15 @@ ...@@ -598,6 +714,15 @@
" plt.title('Model Confidence', fontsize=20, pad=20)\n", " plt.title('Model Confidence', fontsize=20, pad=20)\n",
" return plt\n", " return plt\n",
"\n", "\n",
"# Show the structure coloured by chain if the multimer model has been used.\n",
"if model_type == ModelType.MULTIMER:\n",
" multichain_view = py3Dmol.view(width=800, height=600)\n",
" multichain_view.addModelsAsFrames(to_visualize_pdb)\n",
" multichain_style = {'cartoon': {'colorscheme': 'chain'}}\n",
" multichain_view.setStyle({'model': -1}, multichain_style)\n",
" multichain_view.zoomTo()\n",
" multichain_view.show()\n",
"\n",
"# Color the structure by per-residue pLDDT\n", "# Color the structure by per-residue pLDDT\n",
"color_map = {i: bands[2] for i, bands in enumerate(PLDDT_BANDS)}\n", "color_map = {i: bands[2] for i, bands in enumerate(PLDDT_BANDS)}\n",
"view = py3Dmol.view(width=800, height=600)\n", "view = py3Dmol.view(width=800, height=600)\n",
...@@ -643,6 +768,15 @@ ...@@ -643,6 +768,15 @@
" pae, max_pae = list(pae_outputs.values())[0]\n", " pae, max_pae = list(pae_outputs.values())[0]\n",
" plt.imshow(pae, vmin=0., vmax=max_pae, cmap='Greens_r')\n", " plt.imshow(pae, vmin=0., vmax=max_pae, cmap='Greens_r')\n",
" plt.colorbar(fraction=0.046, pad=0.04)\n", " plt.colorbar(fraction=0.046, pad=0.04)\n",
"\n",
" # Display lines at chain boundaries.\n",
" best_unrelaxed_prot = unrelaxed_proteins[best_model_name]\n",
" total_num_res = best_unrelaxed_prot.residue_index.shape[-1]\n",
" chain_ids = best_unrelaxed_prot.chain_index\n",
" for chain_boundary in np.nonzero(chain_ids[:-1] - chain_ids[1:]):\n",
" if chain_boundary.size:\n",
" plt.plot([0, total_num_res], [chain_boundary, chain_boundary], color='red')\n",
" plt.plot([chain_boundary, chain_boundary], [0, total_num_res], color='red')\n",
" plt.title('Predicted Aligned Error')\n", " plt.title('Predicted Aligned Error')\n",
" plt.xlabel('Scored residue')\n", " plt.xlabel('Scored residue')\n",
" plt.ylabel('Aligned residue')\n", " plt.ylabel('Aligned residue')\n",
...@@ -680,7 +814,7 @@ ...@@ -680,7 +814,7 @@
"source": [ "source": [
"### Interpreting the prediction\n", "### Interpreting the prediction\n",
"\n", "\n",
"Please see the [AlphaFold methods paper](https://www.nature.com/articles/s41586-021-03819-2) and the [AlphaFold predictions of the human proteome paper](https://www.nature.com/articles/s41586-021-03828-1), as well as [DeepMind's FAQ](https://alphafold.ebi.ac.uk/faq) on how to interpret AlphaFold/OpenFold predictions." "Please see the [AlphaFold methods paper](https://www.nature.com/articles/s41586-021-03819-2) and the [AlphaFold predictions of the human proteome paper](https://www.nature.com/articles/s41586-021-03828-1), as well as [DeepMind's FAQ](https://alphafold.ebi.ac.uk/faq) on how to interpret AlphaFold/OpenFold predictions. More information about the predictions of the AlphaFold Multimer model can be found in the [Alphafold Multimer paper](https://www.biorxiv.org/content/10.1101/2022.03.11.484043v3.full.pdf)."
] ]
}, },
{ {
...@@ -718,7 +852,7 @@ ...@@ -718,7 +852,7 @@
" * Colab CPU runtime is not supported. Try changing your runtime via _Runtime_ > _Change runtime type_ > _Hardware accelerator_ > _GPU_.\n", " * Colab CPU runtime is not supported. Try changing your runtime via _Runtime_ > _Change runtime type_ > _Hardware accelerator_ > _GPU_.\n",
" * The type of GPU allocated to your Colab varies. See the [Colab FAQ](https://research.google.com/colaboratory/faq.html) for more details.\n", " * The type of GPU allocated to your Colab varies. See the [Colab FAQ](https://research.google.com/colaboratory/faq.html) for more details.\n",
" * If you receive “Cannot connect to GPU backend”, you can try again later to see if Colab allocates you a GPU.\n", " * If you receive “Cannot connect to GPU backend”, you can try again later to see if Colab allocates you a GPU.\n",
" * [Colab Pro](https://colab.research.google.com/signup) offers priority access to GPUs. \n", " * [Colab Pro](https://colab.research.google.com/signup) offers priority access to GPUs.\n",
"* Does this tool install anything on my computer?\n", "* Does this tool install anything on my computer?\n",
" * No, everything happens in the cloud on Google Colab.\n", " * No, everything happens in the cloud on Google Colab.\n",
" * At the end of the Colab execution a zip-archive with the obtained prediction will be automatically downloaded to your computer.\n", " * At the end of the Colab execution a zip-archive with the obtained prediction will be automatically downloaded to your computer.\n",
...@@ -766,13 +900,10 @@ ...@@ -766,13 +900,10 @@
} }
], ],
"metadata": { "metadata": {
"accelerator": "GPU",
"colab": { "colab": {
"collapsed_sections": [],
"name": "OpenFold.ipynb",
"provenance": [], "provenance": [],
"gpuType": "T4", "gpuType": "T4",
"include_colab_link": true "toc_visible": true
}, },
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3",
...@@ -780,8 +911,9 @@ ...@@ -780,8 +911,9 @@
}, },
"language_info": { "language_info": {
"name": "python" "name": "python"
} },
"accelerator": "GPU"
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 0 "nbformat_minor": 0
} }
\ No newline at end of file
...@@ -3,15 +3,15 @@ channels: ...@@ -3,15 +3,15 @@ channels:
- conda-forge - conda-forge
- bioconda - bioconda
dependencies: dependencies:
- conda-forge::openmm=7.5.1 - openmm=7.7
- conda-forge::pdbfixer - pdbfixer
- ml-collections
- PyYAML==5.4.1
- requests
- typing-extensions
- bioconda::hmmer==3.3.2 - bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0 - bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04 - bioconda::kalign2==2.04
- pip: - pip:
- biopython==1.79 - biopython==1.79
- dm-tree==0.1.6 - dm-tree==0.1.6
- ml-collections==0.1.0
- PyYAML==5.4.1
- requests==2.26.0
- typing-extensions==3.10.0.2
...@@ -21,14 +21,11 @@ import dataclasses ...@@ -21,14 +21,11 @@ import dataclasses
from multiprocessing import cpu_count from multiprocessing import cpu_count
import tempfile import tempfile
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
import subprocess
import numpy as np import numpy as np
import torch import torch
import pickle
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
from openfold.data.templates import get_custom_template_features, empty_template_feats from openfold.data.templates import get_custom_template_features, empty_template_feats
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.data.tools.utils import to_date
from openfold.np import residue_constants, protein from openfold.np import residue_constants, protein
FeatureDict = MutableMapping[str, np.ndarray] FeatureDict = MutableMapping[str, np.ndarray]
...@@ -704,10 +701,10 @@ class DataPipeline: ...@@ -704,10 +701,10 @@ class DataPipeline:
def _parse_msa_data( def _parse_msa_data(
self, self,
alignment_dir: str, alignment_dir: str,
alignment_index: Optional[Any] = None, alignment_index: Optional[Any] = None
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
msa_data = {} msa_data = {}
if(alignment_index is not None): if alignment_index is not None:
fp = open(os.path.join(alignment_dir, alignment_index["db"]), "rb") fp = open(os.path.join(alignment_dir, alignment_index["db"]), "rb")
def read_msa(start, size): def read_msa(start, size):
...@@ -718,14 +715,14 @@ class DataPipeline: ...@@ -718,14 +715,14 @@ class DataPipeline:
for (name, start, size) in alignment_index["files"]: for (name, start, size) in alignment_index["files"]:
filename, ext = os.path.splitext(name) filename, ext = os.path.splitext(name)
if(ext == ".a3m"): if ext == ".a3m":
msa = parsers.parse_a3m( msa = parsers.parse_a3m(
read_msa(start, size) read_msa(start, size)
) )
# The "hmm_output" exception is a crude way to exclude # The "hmm_output" exception is a crude way to exclude
# multimer template hits. # multimer template hits.
# Multimer "uniprot_hits" processed separately. # Multimer "uniprot_hits" processed separately.
elif(ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]): elif ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]:
msa = parsers.parse_stockholm(read_msa(start, size)) msa = parsers.parse_stockholm(read_msa(start, size))
else: else:
continue continue
...@@ -734,13 +731,22 @@ class DataPipeline: ...@@ -734,13 +731,22 @@ class DataPipeline:
fp.close() fp.close()
else: else:
# Now will split the following steps into multiple processes for f in os.listdir(alignment_dir):
current_directory = os.path.dirname(os.path.abspath(__file__)) path = os.path.join(alignment_dir, f)
cmd = f"{current_directory}/tools/parse_msa_files.py" filename, ext = os.path.splitext(f)
msa_data_path = subprocess.run(['python',cmd, f"--alignment_dir={alignment_dir}"],capture_output=True, text=True)
msa_data_path = msa_data_path.stdout.lstrip().rstrip() if ext == ".a3m":
msa_data = pickle.load((open(msa_data_path,'rb'))) with open(path, "r") as fp:
os.remove(msa_data_path) msa = parsers.parse_a3m(fp.read())
elif ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]:
with open(path, "r") as fp:
msa = parsers.parse_stockholm(
fp.read()
)
else:
continue
msa_data[f] = msa
return msa_data return msa_data
......
...@@ -101,8 +101,8 @@ def empty_template_feats(n_res): ...@@ -101,8 +101,8 @@ def empty_template_feats(n_res):
"template_all_atom_positions": np.zeros( "template_all_atom_positions": np.zeros(
(0, n_res, residue_constants.atom_type_num, 3), np.float32 (0, n_res, residue_constants.atom_type_num, 3), np.float32
), ),
"template_domain_names": np.array([''.encode()], dtype=np.object), "template_domain_names": np.array([''.encode()], dtype=object),
"template_sequence": np.array([''.encode()], dtype=np.object), "template_sequence": np.array([''.encode()], dtype=object),
"template_sum_probs": np.zeros((0, 1), dtype=np.float32), "template_sum_probs": np.zeros((0, 1), dtype=np.float32),
} }
......
...@@ -90,15 +90,15 @@ def get_optimal_transform( ...@@ -90,15 +90,15 @@ def get_optimal_transform(
def get_least_asym_entity_or_longest_length(batch, input_asym_id): def get_least_asym_entity_or_longest_length(batch, input_asym_id):
""" """
First check how many subunit(s) one sequence has, if there is no tie, e.g. AABBB then select First check how many subunit(s) one sequence has. Select the subunit that is less
one of the A as anchor common, e.g. if the protein was AABBB then select one of the A as anchor
If there is a tie, e.g. AABB, first check which sequence is the longer/longest, If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
then choose one of the corresponding subunits as anchor then choose one of the corresponding subunits as anchor
Args: Args:
batch: in this funtion batch is the full ground truth features batch: in this function batch is the full ground truth features
input_asym_id: A list of aym_ids that are in the cropped input features input_asym_id: A list of asym_ids that are in the cropped input features
Return: Return:
anchor_gt_asym_id: Tensor(int) selected ground truth asym_id anchor_gt_asym_id: Tensor(int) selected ground truth asym_id
...@@ -126,7 +126,7 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id): ...@@ -126,7 +126,7 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
min_asym_count = min(entity_asym_count.values()) min_asym_count = min(entity_asym_count.values())
least_asym_entities = [entity for entity, count in entity_asym_count.items() if count == min_asym_count] least_asym_entities = [entity for entity, count in entity_asym_count.items() if count == min_asym_count]
# If multiple entities have the least asym_id count, return those with the shortest length # If multiple entities have the least asym_id count, return those with the longest length
if len(least_asym_entities) > 1: if len(least_asym_entities) > 1:
max_length = max([entity_length[entity] for entity in least_asym_entities]) max_length = max([entity_length[entity] for entity in least_asym_entities])
least_asym_entities = [entity for entity in least_asym_entities if entity_length[entity] == max_length] least_asym_entities = [entity for entity in least_asym_entities if entity_length[entity] == max_length]
......
...@@ -123,7 +123,7 @@ def parse_fasta(data): ...@@ -123,7 +123,7 @@ def parse_fasta(data):
][1:] ][1:]
tags, seqs = lines[::2], lines[1::2] tags, seqs = lines[::2], lines[1::2]
tags = [t.split()[0] for t in tags] tags = [re.split('\W| \|', t)[0] for t in tags]
return tags, seqs return tags, seqs
......
...@@ -63,10 +63,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args): ...@@ -63,10 +63,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
with open(tmp_fasta_path, "w") as fp: with open(tmp_fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}") fp.write(f">{tag}\n{seq}")
local_alignment_dir = os.path.join( local_alignment_dir = os.path.join(alignment_dir, tag)
alignment_dir,
os.path.join(alignment_dir, tag),
)
if args.use_precomputed_alignments is None: if args.use_precomputed_alignments is None:
logger.info(f"Generating alignments for {tag}...") logger.info(f"Generating alignments for {tag}...")
......
...@@ -113,10 +113,10 @@ else: ...@@ -113,10 +113,10 @@ else:
setup( setup(
name='openfold', name='openfold',
version='1.0.1', version='2.0.0',
description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2', description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2',
author='Gustaf Ahdritz & DeepMind', author='OpenFold Team',
author_email='gahdritz@gmail.com', author_email='jennifer.wei@omsf.io',
license='Apache License, Version 2.0', license='Apache License, Version 2.0',
url='https://github.com/aqlaboratory/openfold', url='https://github.com/aqlaboratory/openfold',
packages=find_packages(exclude=["tests", "scripts"]), packages=find_packages(exclude=["tests", "scripts"]),
......
import ml_collections as mlc import ml_collections as mlc
consts = mlc.ConfigDict(
monomer_consts = mlc.ConfigDict(
{
"model": "model_1_ptm", # monomer:model_1_ptm, multimer: model_1_multimer_v3
"is_multimer": False, # monomer: False, multimer: True
"chunk_size": 4,
"batch_size": 2,
"n_res": 22,
"n_seq": 13,
"n_templ": 3,
"n_extra": 17,
"n_heads_extra_msa": 8,
"eps": 5e-4,
# For compatibility with DeepMind's pretrained weights, it's easiest for
# everyone if these take their real values.
"c_m": 256,
"c_z": 128,
"c_s": 384,
"c_t": 64,
"c_e": 64,
"msa_logits": 23, # monomer: 23, multimer: 22
"template_mmcif_dir": None # Set for test_multimer_datamodule
}
)
multimer_consts = mlc.ConfigDict(
{ {
"model": "model_1_multimer_v3", # monomer:model_1_ptm, multimer: model_1_multimer_v3 "model": "model_1_multimer_v3", # monomer:model_1_ptm, multimer: model_1_multimer_v3
"is_multimer": True, # monomer: False, multimer: True "is_multimer": True, # monomer: False, multimer: True
...@@ -24,6 +49,8 @@ consts = mlc.ConfigDict( ...@@ -24,6 +49,8 @@ consts = mlc.ConfigDict(
} }
) )
consts = monomer_consts
config = mlc.ConfigDict( config = mlc.ConfigDict(
{ {
"data": { "data": {
......
...@@ -244,9 +244,6 @@ class TestDeepSpeedKernel(unittest.TestCase): ...@@ -244,9 +244,6 @@ class TestDeepSpeedKernel(unittest.TestCase):
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32) pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32) pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
inds = np.random.randint(0, 21, (n_res,))
batch["target_feat"] = np.eye(22)[inds]
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()} batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
template_feats = { template_feats = {
k: v for k, v in batch.items() if k.startswith("template_") k: v for k, v in batch.items() if k.startswith("template_")
...@@ -309,7 +306,7 @@ class TestDeepSpeedKernel(unittest.TestCase): ...@@ -309,7 +306,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
batch["residx_atom37_to_atom14"] = batch[ batch["residx_atom37_to_atom14"] = batch[
"residx_atom37_to_atom14" "residx_atom37_to_atom14"
].long() ].long()
batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], 21).to(torch.float32) batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], consts.msa_logits - 1).to(torch.float32)
batch["template_all_atom_mask"] = batch["template_all_atom_masks"] batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
batch.update( batch.update(
data_transforms.atom37_to_torsion_angles("template_")(batch) data_transforms.atom37_to_torsion_angles("template_")(batch)
......
...@@ -21,7 +21,6 @@ from openfold.utils.multi_chain_permutation import (pad_features, get_least_asym ...@@ -21,7 +21,6 @@ from openfold.utils.multi_chain_permutation import (pad_features, get_least_asym
merge_labels) merge_labels)
@unittest.skip("Tests need to be fixed post-refactor")
class TestPermutation(unittest.TestCase): class TestPermutation(unittest.TestCase):
def setUp(self): def setUp(self):
""" """
...@@ -65,10 +64,16 @@ class TestPermutation(unittest.TestCase): ...@@ -65,10 +64,16 @@ class TestPermutation(unittest.TestCase):
'seq_length': torch.tensor([57]) 'seq_length': torch.tensor([57])
} }
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch, batch['asym_id']) anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch, batch['asym_id'])
self.assertIn(int(anchor_gt_asym), [1, 2]) anchor_gt_asym = int(anchor_gt_asym)
self.assertNotIn(int(anchor_gt_asym), [3, 4, 5]) anchor_pred_asym = {int(i) for i in anchor_pred_asym}
self.assertIn(int(anchor_pred_asym), [1, 2]) expected_anchors = {1, 2}
self.assertNotIn(int(anchor_pred_asym), [3, 4, 5]) expected_non_anchors = {3, 4, 5}
self.assertIn(anchor_gt_asym, expected_anchors)
self.assertNotIn(anchor_gt_asym, expected_non_anchors)
# Check that predicted anchors are within expected anchor set
self.assertEqual(anchor_pred_asym, expected_anchors & anchor_pred_asym)
self.assertEqual(set(), anchor_pred_asym & expected_non_anchors)
def test_2_permutation_pentamer(self): def test_2_permutation_pentamer(self):
batch = { batch = {
...@@ -114,6 +119,7 @@ class TestPermutation(unittest.TestCase): ...@@ -114,6 +119,7 @@ class TestPermutation(unittest.TestCase):
self.assertIn(aligns, possible_outcome) self.assertIn(aligns, possible_outcome)
self.assertNotIn(aligns, wrong_outcome) self.assertNotIn(aligns, wrong_outcome)
@unittest.skip("Test needs to be fixed post-refactor")
def test_3_merge_labels(self): def test_3_merge_labels(self):
nres_pad = 325 - 57 # suppose the cropping size is 325 nres_pad = 325 - 57 # suppose the cropping size is 325
batch = { batch = {
......
...@@ -235,6 +235,7 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -235,6 +235,7 @@ class OpenFoldWrapper(pl.LightningModule):
lr_scheduler = AlphaFoldLRScheduler( lr_scheduler = AlphaFoldLRScheduler(
optimizer, optimizer,
last_epoch=self.last_lr_step
) )
return { return {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment