"README_ORIGIN.md" did not exist on "3a2c1480313e559c5f47f9af132a18a6472343cf"
Unverified Commit 49ab0539 authored by Jennifer Wei's avatar Jennifer Wei Committed by GitHub
Browse files

Merge pull request #407 from jnwei/pl_upgrades

Pytorch lightning upgrades
parents df4dfacb f0fc7d91
......@@ -5,7 +5,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
- uses: actions/setup-python@v5
- run: pip install --upgrade pip
- run: pip install flake8
- run: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
......@@ -9,4 +9,4 @@ dist
data
openfold/resources/
tests/test_data/
cutlass
cutlass/
......@@ -13,7 +13,7 @@ RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/
RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-11-3 libcusparse-dev-11-3 libcublas-dev-11-3 libcusolver-dev-11-3 git
RUN wget -P /tmp \
"https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh" \
"https://github.com/conda-forge/miniforge/releases/download/23.3.1-1/Miniforge3-Linux-x86_64.sh" \
&& bash /tmp/Miniforge3-Linux-x86_64.sh -b -p /opt/conda \
&& rm /tmp/Miniforge3-Linux-x86_64.sh
ENV PATH /opt/conda/bin:$PATH
......
......@@ -351,7 +351,7 @@ python3 run_pretrained_openfold.py \
--output_dir ./ \
--model_device "cuda:0" \
--config_preset "seq_model_esm1b_ptm" \
--openfold_checkpoint_path openfold/resources/openfold_params/seq_model_esm1b_ptm.pt \
--openfold_checkpoint_path openfold/resources/openfold_soloseq_params/seq_model_esm1b_ptm.pt \
--uniref90_database_path data/uniref90/uniref90.fasta \
--pdb70_database_path data/pdb70/pdb70 \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
......@@ -595,4 +595,4 @@ If you use OpenProteinSet, please also cite:
primaryClass={q-bio.BM}
}
```
Any work that cites OpenFold should also cite AlphaFold.
Any work that cites OpenFold should also cite [AlphaFold](https://www.nature.com/articles/s41586-021-03819-2) and [AlphaFold-Multimer](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1) if applicable.
......@@ -3,8 +3,7 @@
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
......@@ -52,25 +51,44 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "rowN0bVYLe9n"
},
"outputs": [],
"source": [
"#@markdown ### Enter the amino acid sequence to fold ⬇️\n",
"sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n",
"#@markdown For multiple sequences, separate sequences with a colon `:`\n",
"input_sequence = 'MKLKQVADKLEEVASKLYHNANELARVAKLLGER:MKLKQVADKLEEVASKLYHNANELARVAKLLGER: MKLKQVADKLEEVASKLYHNANELARVAKLLGER:MKLKQVADKLEEVASKLYHNANELARVAKLLGER' #@param {type:\"string\"}\n",
"\n",
"#@markdown ### Configure the model ⬇️\n",
"\n",
"weight_set = 'OpenFold' #@param [\"OpenFold\", \"AlphaFold\"]\n",
"weight_set = 'AlphaFold' #@param [\"OpenFold\", \"AlphaFold\"]\n",
"model_mode = 'multimer' #@param [\"monomer\", \"multimer\"]\n",
"relax_prediction = True #@param {type:\"boolean\"}\n",
"\n",
"\n",
"# Remove all whitespaces, tabs and end lines; upper-case\n",
"sequence = sequence.translate(str.maketrans('', '', ' \\n\\t')).upper()\n",
"input_sequence = input_sequence.translate(str.maketrans('', '', ' \\n\\t')).upper()\n",
"aatypes = set('ACDEFGHIKLMNPQRSTVWY') # 20 standard aatypes\n",
"if not set(sequence).issubset(aatypes):\n",
" raise Exception(f'Input sequence contains non-amino acid letters: {set(sequence) - aatypes}. OpenFold only supports 20 standard amino acids as inputs.')\n",
"\n",
"allowed_chars = aatypes.union({':'})\n",
"if not set(input_sequence).issubset(allowed_chars):\n",
" raise Exception(f'Input sequence contains non-amino acid letters: {set(input_sequence) - allowed_chars}. OpenFold only supports 20 standard amino acids as inputs.')\n",
"\n",
"if ':' in input_sequence and weight_set != 'AlphaFold':\n",
" raise ValueError('Input sequence is a multimer, must select Alphafold weight set')\n",
"\n",
"import enum\n",
"@enum.unique\n",
"class ModelType(enum.Enum):\n",
" MONOMER = 0\n",
" MULTIMER = 1\n",
"\n",
"model_type_dict = {\n",
" 'monomer': ModelType.MONOMER,\n",
" 'multimer': ModelType.MULTIMER,\n",
"}\n",
"\n",
"model_type = model_type_dict[model_mode]\n",
"print(f'Length of input sequence : {len(input_sequence.replace(\":\", \"\"))}')\n",
"#@markdown After making your selections, execute this cell by pressing the\n",
"#@markdown *Play* button on the left."
]
......@@ -79,17 +97,16 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "woIxeCPygt7K"
},
"outputs": [],
"source": [
"#@title Install third-party software\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@markdown Please execute this cell by pressing the *Play* button on\n",
"#@markdown the left.\n",
"\n",
"\n",
"#@markdown **Note**: This installs the software on the Colab \n",
"#@markdown **Note**: This installs the software on the Colab\n",
"#@markdown notebook in the cloud and not on your computer.\n",
"\n",
"import os, time\n",
......@@ -103,10 +120,8 @@
"os.system(\"wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh\")\n",
"os.system(\"bash Mambaforge-Linux-x86_64.sh -bfp /usr/local\")\n",
"os.system(\"mamba config --set auto_update_conda false\")\n",
"os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=7.7.0 python={python_version} pdbfixer\")\n",
"\n",
"\n",
"os.system(\"pip install -q \\\"torch<2\\\" biopython ml_collections py3Dmol modelcif\")\n",
"os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=7.7.0 python={python_version} pdbfixer biopython=1.79\")\n",
"os.system(\"pip install -q torch ml_collections py3Dmol modelcif\")\n",
"\n",
"try:\n",
" with io.capture_output() as captured:\n",
......@@ -119,12 +134,12 @@
" %shell wget -q -P /content \\\n",
" https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt\n",
"\n",
" %shell mkdir -p /content/openfold/openfold/resourcees\n",
" \n",
" commit = \"099769d2ecfd01a8baa8d950030df454a042c910\"\n",
" %shell mkdir -p /content/openfold/openfold/resources\n",
"\n",
" commit = \"e2e19f16676b1a409f9ba3a6f69b11ee7f5887c2\"\n",
" os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n",
" \n",
" %shell cp -f /content/stereo_chemical_props.txt /usr/local/lib/python3.10/site-packages/openfold/resources/\n",
"\n",
" os.system(f\"cp -f -p /content/stereo_chemical_props.txt /usr/local/lib/python{python_version}/site-packages/openfold/resources/\")\n",
"\n",
"except subprocess.CalledProcessError as captured:\n",
" print(captured)"
......@@ -134,18 +149,17 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "VzJ5iMjTtoZw"
},
"outputs": [],
"source": [
"#@title Download model weights \n",
"#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@title Download model weights\n",
"#@markdown Please execute this cell by pressing the *Play* button on\n",
"#@markdown the left.\n",
"\n",
"# Define constants\n",
"GIT_REPO='https://github.com/aqlaboratory/openfold'\n",
"ALPHAFOLD_PARAM_SOURCE_URL = 'https://storage.googleapis.com/alphafold/alphafold_params_2022-01-19.tar'\n",
"ALPHAFOLD_PARAM_SOURCE_URL = 'https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar'\n",
"OPENFOLD_PARAMS_DIR = './openfold/openfold/resources/openfold_params'\n",
"ALPHAFOLD_PARAMS_DIR = './openfold/openfold/resources/params'\n",
"ALPHAFOLD_PARAMS_PATH = os.path.join(\n",
......@@ -184,17 +198,17 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "_FpxxMo-mvcP"
},
"outputs": [],
"source": [
"#@title Import Python packages\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@markdown Please execute this cell by pressing the *Play* button on\n",
"#@markdown the left.\n",
"\n",
"import unittest.mock\n",
"import sys\n",
"from typing import Dict, Sequence\n",
"\n",
"sys.path.insert(0, f'/usr/local/lib/python{python_version}/dist-packages/')\n",
"sys.path.insert(0, f'/usr/local/lib/python{python_version}/site-packages/')\n",
......@@ -234,22 +248,12 @@
" return \"UTF-8\"\n",
"locale.getpreferredencoding = getpreferredencoding\n",
"\n",
"# A filthy hack to avoid slow Linear layer initialization\n",
"import openfold.model.primitives\n",
"\n",
"def __default_linear_init__(self, *args, **kwargs):\n",
" return torch.nn.Linear.__init__(\n",
" self, \n",
" *args[:2], \n",
" **{k:v for k,v in kwargs.items() if k == \"bias\"}\n",
" )\n",
"\n",
"openfold.model.primitives.Linear.__init__ = __default_linear_init__\n",
"\n",
"from openfold import config\n",
"from openfold.data import feature_pipeline\n",
"from openfold.data import parsers\n",
"from openfold.data import data_pipeline\n",
"from openfold.data import msa_pairing\n",
"from openfold.data import feature_processing_multimer\n",
"from openfold.data.tools import jackhmmer\n",
"from openfold.model import model\n",
"from openfold.np import protein\n",
......@@ -276,22 +280,16 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "2tTeTTsLKPjB"
},
"outputs": [],
"source": [
"#@title Search against genetic databases\n",
"\n",
"#@markdown Once this cell has been executed, you will see\n",
"#@markdown statistics about the multiple sequence alignment \n",
"#@markdown (MSA) that will be used by OpenFold. In particular, \n",
"#@markdown you’ll see how well each residue is covered by similar \n",
"#@markdown statistics about the multiple sequence alignment\n",
"#@markdown (MSA) that will be used by OpenFold. In particular,\n",
"#@markdown you’ll see how well each residue is covered by similar\n",
"#@markdown sequences in the MSA.\n",
"\n",
"# --- Find the closest source ---\n",
"# --- Find the closest source --\n",
"test_url_pattern = 'https://storage.googleapis.com/alphafold-colab{:s}/latest/uniref90_2021_03.fasta.1'\n",
"ex = futures.ThreadPoolExecutor(3)\n",
"def fetch(source):\n",
......@@ -304,114 +302,156 @@
" ex.shutdown()\n",
" break\n",
"\n",
"# --- Search against genetic databases ---\n",
"with open('target.fasta', 'wt') as f:\n",
" f.write(f'>query\\n{sequence}')\n",
"\n",
"# Run the search against chunks of genetic databases (since the genetic\n",
"# databases don't fit in Colab ramdisk).\n",
"\n",
"jackhmmer_binary_path = '/usr/bin/jackhmmer'\n",
"dbs = []\n",
"\n",
"num_jackhmmer_chunks = {'uniref90': 59, 'smallbfd': 17, 'mgnify': 71}\n",
"total_jackhmmer_chunks = sum(num_jackhmmer_chunks.values())\n",
"# --- Parse multiple sequences, if there are any ---\n",
"def split_multiple_sequences(sequence):\n",
" seqs = sequence.split(':')\n",
" sorted_seqs = sorted(seqs, key=lambda s: len(s))\n",
"\n",
" # TODO: Handle the homomer case when writing fasta sequences\n",
" fasta_path_tuples = []\n",
" for idx, seq in enumerate(set(sorted_seqs)):\n",
" fasta_path = f'target_{idx+1}.fasta'\n",
" with open(fasta_path, 'wt') as f:\n",
" f.write(f'>query\\n{seq}\\n')\n",
" fasta_path_tuples.append((seq, fasta_path))\n",
" fasta_path_by_seq = dict(fasta_path_tuples)\n",
"\n",
" return sorted_seqs, fasta_path_by_seq\n",
"\n",
"sequences, fasta_path_by_sequence = split_multiple_sequences(input_sequence)\n",
"db_results_by_sequence = {seq: {} for seq in fasta_path_by_sequence.keys()}\n",
"\n",
"DB_ROOT_PATH = f'https://storage.googleapis.com/alphafold-colab{source}/latest/'\n",
"db_configs = {}\n",
"db_configs['smallbfd'] = {\n",
" 'database_path': f'{DB_ROOT_PATH}uniref90_2021_03.fasta',\n",
" 'z_value': 65984053,\n",
" 'num_jackhmmer_chunks': 17,\n",
"}\n",
"db_configs['mgnify'] = {\n",
" 'database_path': f'{DB_ROOT_PATH}mgy_clusters_2022_05.fasta',\n",
" 'z_value': 304820129,\n",
" 'num_jackhmmer_chunks': 120,\n",
"}\n",
"db_configs['uniref90'] = {\n",
" 'database_path': f'{DB_ROOT_PATH}uniref90_2022_01.fasta',\n",
" 'z_value': 144113457,\n",
" 'num_jackhmmer_chunks': 62,\n",
"}\n",
"\n",
"# Search UniProt and construct the all_seq features only for heteromers, not homomers.\n",
"if model_type == ModelType.MULTIMER and len(set(sequences)) > 1:\n",
" db_configs['uniprot'] = {\n",
" 'database_path': f'{DB_ROOT_PATH}uniprot_2021_04.fasta',\n",
" 'z_value': 225013025 + 565928,\n",
" 'num_jackhmmer_chunks': 101,\n",
" }\n",
"\n",
"total_jackhmmer_chunks = sum([d['num_jackhmmer_chunks'] for d in db_configs.values()])\n",
"with tqdm.notebook.tqdm(total=total_jackhmmer_chunks, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" def jackhmmer_chunk_callback(i):\n",
" pbar.update(n=1)\n",
"\n",
" pbar.set_description('Searching uniref90')\n",
" jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(\n",
" binary_path=jackhmmer_binary_path,\n",
" database_path=f'https://storage.googleapis.com/alphafold-colab{source}/latest/uniref90_2021_03.fasta',\n",
" get_tblout=True,\n",
" num_streamed_chunks=num_jackhmmer_chunks['uniref90'],\n",
" streaming_callback=jackhmmer_chunk_callback,\n",
" z_value=135301051)\n",
" dbs.append(('uniref90', jackhmmer_uniref90_runner.query('target.fasta')))\n",
"\n",
" pbar.set_description('Searching smallbfd')\n",
" jackhmmer_smallbfd_runner = jackhmmer.Jackhmmer(\n",
" for db_name, db_config in db_configs.items():\n",
" pbar.set_description(f'Searching {db_name}')\n",
" jackhmmer_runner = jackhmmer.Jackhmmer(\n",
" binary_path=jackhmmer_binary_path,\n",
" database_path=f'https://storage.googleapis.com/alphafold-colab{source}/latest/bfd-first_non_consensus_sequences.fasta',\n",
" database_path=db_config['database_path'],\n",
" get_tblout=True,\n",
" num_streamed_chunks=num_jackhmmer_chunks['smallbfd'],\n",
" num_streamed_chunks=db_config['num_jackhmmer_chunks'],\n",
" streaming_callback=jackhmmer_chunk_callback,\n",
" z_value=65984053)\n",
" dbs.append(('smallbfd', jackhmmer_smallbfd_runner.query('target.fasta')))\n",
" z_value=db_config['z_value'])\n",
"\n",
" pbar.set_description('Searching mgnify')\n",
" jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(\n",
" binary_path=jackhmmer_binary_path,\n",
" database_path=f'https://storage.googleapis.com/alphafold-colab{source}/latest/mgy_clusters_2019_05.fasta',\n",
" get_tblout=True,\n",
" num_streamed_chunks=num_jackhmmer_chunks['mgnify'],\n",
" streaming_callback=jackhmmer_chunk_callback,\n",
" z_value=304820129)\n",
" dbs.append(('mgnify', jackhmmer_mgnify_runner.query('target.fasta')))\n",
" db_results = jackhmmer_runner.query_multiple(fasta_path_by_sequence.values())\n",
" for seq, result in zip(fasta_path_by_sequence.keys(), db_results):\n",
" db_results_by_sequence[seq][db_name] = result\n",
"\n",
"\n",
"# --- Extract the MSAs and visualize ---\n",
"# Extract the MSAs from the Stockholm files.\n",
"# NB: deduplication happens later in data_pipeline.make_msa_features.\n",
"\n",
"mgnify_max_hits = 501\n",
"MAX_HITS_BY_DB = {\n",
" 'uniref90': 10000,\n",
" 'smallbfd': 5000,\n",
" 'mgnify': 501,\n",
" 'uniprot': 50000,\n",
"}\n",
"\n",
"msas_by_seq_by_db = {seq: {} for seq in sequences}\n",
"full_msa_by_seq = {seq: [] for seq in sequences}\n",
"\n",
"msas = []\n",
"deletion_matrices = []\n",
"full_msa = []\n",
"for db_name, db_results in dbs:\n",
"for seq, sequence_result in db_results_by_sequence.items():\n",
" print(f'parsing_results_for_sequence {seq}')\n",
" for db_name, db_results in sequence_result.items():\n",
" unsorted_results = []\n",
" for i, result in enumerate(db_results):\n",
" msa, deletion_matrix, target_names = parsers.parse_stockholm(result['sto'])\n",
" msa_obj = parsers.parse_stockholm(result['sto'])\n",
" e_values_dict = parsers.parse_e_values_from_tblout(result['tbl'])\n",
" target_names = msa_obj.descriptions\n",
" e_values = [e_values_dict[t.split('/')[0]] for t in target_names]\n",
" zipped_results = zip(msa, deletion_matrix, target_names, e_values)\n",
" zipped_results = zip(msa_obj.sequences, msa_obj.deletion_matrix, target_names, e_values)\n",
" if i != 0:\n",
" # Only take query from the first chunk\n",
" zipped_results = [x for x in zipped_results if x[2] != 'query']\n",
" unsorted_results.extend(zipped_results)\n",
" sorted_by_evalue = sorted(unsorted_results, key=lambda x: x[3])\n",
" db_msas, db_deletion_matrices, _, _ = zip(*sorted_by_evalue)\n",
" msas, del_matrix, targets, _ = zip(*sorted_by_evalue)\n",
" db_msas = parsers.Msa(msas, del_matrix, targets)\n",
" if db_msas:\n",
" if db_name == 'mgnify':\n",
" db_msas = db_msas[:mgnify_max_hits]\n",
" db_deletion_matrices = db_deletion_matrices[:mgnify_max_hits]\n",
" full_msa.extend(db_msas)\n",
" msas.append(db_msas)\n",
" deletion_matrices.append(db_deletion_matrices)\n",
" msa_size = len(set(db_msas))\n",
" if db_name in MAX_HITS_BY_DB:\n",
" db_msas.truncate(MAX_HITS_BY_DB[db_name])\n",
" msas_by_seq_by_db[seq][db_name] = db_msas\n",
" full_msa_by_seq[seq].extend(db_msas.sequences)\n",
" msa_size = len(set(db_msas.sequences))\n",
" print(f'{msa_size} Sequences Found in {db_name}')\n",
"\n",
"deduped_full_msa = list(dict.fromkeys(full_msa))\n",
"total_msa_size = len(deduped_full_msa)\n",
"print(f'\\n{total_msa_size} Sequences Found in Total\\n')\n",
"\n",
"aa_map = {restype: i for i, restype in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ-')}\n",
"msa_arr = np.array([[aa_map[aa] for aa in seq] for seq in deduped_full_msa])\n",
"num_alignments, num_res = msa_arr.shape\n",
"\n",
"fig = plt.figure(figsize=(12, 3))\n",
"max_num_alignments = 0\n",
"\n",
"for seq_idx, seq in enumerate(set(sequences)):\n",
" full_msas = full_msa_by_seq[seq]\n",
" deduped_full_msa = list(dict.fromkeys(full_msas))\n",
" total_msa_size = len(deduped_full_msa)\n",
" print(f'\\n{total_msa_size} Sequences Found in Total\\n')\n",
"\n",
" aa_map = {restype: i for i, restype in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ-')}\n",
" msa_arr = np.array([[aa_map[aa] for aa in seq] for seq in deduped_full_msa])\n",
" num_alignments, num_res = msa_arr.shape\n",
" plt.plot(np.sum(msa_arr != aa_map['-'], axis=0), label=f'Chain {seq_idx}')\n",
" max_num_alignments = max(num_alignments, max_num_alignments)\n",
"\n",
"\n",
"plt.title('Per-Residue Count of Non-Gap Amino Acids in the MSA')\n",
"plt.plot(np.sum(msa_arr != aa_map['-'], axis=0), color='black')\n",
"plt.ylabel('Non-Gap Count')\n",
"plt.yticks(range(0, num_alignments + 1, max(1, int(num_alignments / 3))))\n",
"plt.yticks(range(0, max_num_alignments + 1, max(1, int(max_num_alignments / 3))))\n",
"plt.legend()\n",
"plt.show()"
]
],
"metadata": {
"id": "o7BqQN_gfYtq"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "XUo6foMQxwS2"
},
"outputs": [],
"source": [
"#@title Run OpenFold and download prediction\n",
"\n",
"#@markdown Once this cell has been executed, a zip-archive with \n",
"#@markdown the obtained prediction will be automatically downloaded \n",
"#@markdown Once this cell has been executed, a zip-archive with\n",
"#@markdown the obtained prediction will be automatically downloaded\n",
"#@markdown to your computer.\n",
"\n",
"# Color bands for visualizing plddt\n",
......@@ -423,13 +463,22 @@
"]\n",
"\n",
"# --- Run the model ---\n",
"model_names = [ \n",
" 'finetuning_3.pt', \n",
" 'finetuning_4.pt', \n",
" 'finetuning_5.pt', \n",
"if model_type == ModelType.MONOMER:\n",
" model_names = [\n",
" 'finetuning_3.pt',\n",
" 'finetuning_4.pt',\n",
" 'finetuning_5.pt',\n",
" 'finetuning_ptm_2.pt',\n",
" 'finetuning_no_templ_ptm_1.pt'\n",
"]\n",
" ]\n",
"elif model_type == ModelType.MULTIMER:\n",
" model_names = [\n",
" 'model_1_multimer_v3',\n",
" 'model_2_multimer_v3',\n",
" 'model_3_multimer_v3',\n",
" 'model_4_multimer_v3',\n",
" 'model_5_multimer_v3',\n",
" ]\n",
"\n",
"def _placeholder_template_feats(num_templates_, num_res_):\n",
" return {\n",
......@@ -440,26 +489,72 @@
" 'template_sum_probs': np.zeros((num_templates_, 1), dtype=np.float32),\n",
" }\n",
"\n",
"\n",
"def make_features(\n",
" sequences: Sequence[str],\n",
" msas_by_seq_by_db: Dict[str, Dict[str, parsers.Msa]],\n",
" model_type: ModelType):\n",
" num_templates = 1 # Placeholder for generating fake template features\n",
" feature_dict = {}\n",
"\n",
" for idx, seq in enumerate(sequences, start=1):\n",
" _chain_id = f'chain_{idx}'\n",
" num_res = len(seq)\n",
"\n",
" feats = data_pipeline.make_sequence_features(seq, _chain_id, num_res)\n",
" msas_without_uniprot = [msas_by_seq_by_db[seq][db] for db in db_configs.keys() if db != 'uniprot']\n",
" msa_feats = data_pipeline.make_msa_features(msas_without_uniprot)\n",
" feats.update(msa_feats)\n",
" feats.update(_placeholder_template_feats(num_templates, num_res))\n",
"\n",
" if model_type == ModelType.MONOMER:\n",
" feature_dict[seq] = feats\n",
" if model_type == ModelType.MULTIMER:\n",
" # Perform extra pair processing steps for heteromers\n",
" if len(set(sequences)) > 1:\n",
" uniprot_msa = msas_by_seq_by_db[seq]['uniprot']\n",
" uniprot_msa_features = data_pipeline.make_msa_features([uniprot_msa])\n",
" valid_feat_names = msa_pairing.MSA_FEATURES + (\n",
" 'msa_species_identifiers',\n",
" )\n",
" pair_feats = {\n",
" f'{k}_all_seq': v for k, v in uniprot_msa_features.items()\n",
" if k in valid_feat_names\n",
" }\n",
" feats.update(pair_feats)\n",
"\n",
" feats = data_pipeline.convert_monomer_features(feats, _chain_id)\n",
" feature_dict[_chain_id] = feats\n",
"\n",
" if model_type == ModelType.MONOMER:\n",
" np_example = feature_dict[sequences[0]]\n",
" elif model_type == ModelType.MULTIMER:\n",
" all_chain_feats = data_pipeline.add_assembly_features(feature_dict)\n",
" features = feature_processing_multimer.pair_and_merge(all_chain_features=all_chain_feats)\n",
" np_example = data_pipeline.pad_msa(features, 512)\n",
"\n",
" return np_example\n",
"\n",
"\n",
"output_dir = 'prediction'\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"\n",
"plddts = {}\n",
"pae_outputs = {}\n",
"weighted_ptms = {}\n",
"unrelaxed_proteins = {}\n",
"\n",
"with tqdm.notebook.tqdm(total=len(model_names) + 1, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" for i, model_name in list(enumerate(model_names)):\n",
"with tqdm.notebook.tqdm(total=len(model_names), bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" for i, model_name in enumerate(model_names, start = 1):\n",
" pbar.set_description(f'Running {model_name}')\n",
" num_templates = 1 # dummy number --- is ignored\n",
" num_res = len(sequence)\n",
" \n",
" feature_dict = {}\n",
" feature_dict.update(data_pipeline.make_sequence_features(sequence, 'test', num_res))\n",
" feature_dict.update(data_pipeline.make_msa_features(msas, deletion_matrices=deletion_matrices))\n",
" feature_dict.update(_placeholder_template_feats(num_templates, num_res))\n",
"\n",
" feature_dict = make_features(sequences, msas_by_seq_by_db, model_type)\n",
"\n",
" if(weight_set == \"AlphaFold\"):\n",
" if model_type == ModelType.MONOMER:\n",
" config_preset = f\"model_{i}\"\n",
" elif model_type == ModelType.MULTIMER:\n",
" config_preset = f'model_{i}_multimer_v3'\n",
" else:\n",
" if(\"_no_templ_\" in model_name):\n",
" config_preset = \"model_3\"\n",
......@@ -469,6 +564,11 @@
" config_preset += \"_ptm\"\n",
"\n",
" cfg = config.model_config(config_preset)\n",
"\n",
" # Force the model to only use 3 recycling updates\n",
" cfg.data.common.max_recycling_iters = 3\n",
" cfg.model.recycle_early_stop_tolerance = -1\n",
"\n",
" openfold_model = model.AlphaFold(cfg)\n",
" openfold_model = openfold_model.eval()\n",
" if(weight_set == \"AlphaFold\"):\n",
......@@ -490,7 +590,9 @@
"\n",
" pipeline = feature_pipeline.FeaturePipeline(cfg.data)\n",
" processed_feature_dict = pipeline.process_features(\n",
" feature_dict, mode='predict'\n",
" feature_dict,\n",
" mode='predict',\n",
" is_multimer = (model_type == ModelType.MULTIMER),\n",
" )\n",
"\n",
" processed_feature_dict = tensor_tree_map(\n",
......@@ -510,6 +612,7 @@
"\n",
" mean_plddt = prediction_result['plddt'].mean()\n",
"\n",
" if model_type == ModelType.MONOMER:\n",
" if 'predicted_aligned_error' in prediction_result:\n",
" pae_outputs[model_name] = (\n",
" prediction_result['predicted_aligned_error'],\n",
......@@ -519,12 +622,22 @@
" # Get the pLDDT confidence metrics. Do not put pTM models here as they\n",
" # should never get selected.\n",
" plddts[model_name] = prediction_result['plddt']\n",
" elif model_type == ModelType.MULTIMER:\n",
" # Multimer models are sorted by pTM+ipTM.\n",
" plddts[model_name] = prediction_result['plddt']\n",
" pae_outputs[model_name] = (prediction_result['predicted_aligned_error'],\n",
" prediction_result['max_predicted_aligned_error'])\n",
"\n",
" weighted_ptms[model_name] = prediction_result['weighted_ptm_score']\n",
"\n",
" # Set the b-factors to the per-residue plddt.\n",
" final_atom_mask = prediction_result['final_atom_mask']\n",
" b_factors = prediction_result['plddt'][:, None] * final_atom_mask\n",
" unrelaxed_protein = protein.from_prediction(\n",
" processed_feature_dict, prediction_result, b_factors=b_factors\n",
" processed_feature_dict,\n",
" prediction_result,\n",
" remove_leading_feature_dimension=False,\n",
" b_factors=b_factors,\n",
" )\n",
" unrelaxed_proteins[model_name] = unrelaxed_protein\n",
"\n",
......@@ -535,7 +648,10 @@
" pbar.update(n=1)\n",
"\n",
" # Find the best model according to the mean pLDDT.\n",
" if model_type == ModelType.MONOMER:\n",
" best_model_name = max(plddts.keys(), key=lambda x: plddts[x].mean())\n",
" elif model_type == ModelType.MULTIMER:\n",
" best_model_name = max(weighted_ptms.keys(), key=lambda x: weighted_ptms[x])\n",
" best_pdb = protein.to_pdb(unrelaxed_proteins[best_model_name])\n",
"\n",
" # --- AMBER relax the best model ---\n",
......@@ -547,7 +663,7 @@
" stiffness=10.0,\n",
" exclude_residues=[],\n",
" max_outer_iterations=20,\n",
" use_gpu=False,\n",
" use_gpu=True,\n",
" )\n",
" relaxed_pdb, _, _ = amber_relaxer.process(\n",
" prot=unrelaxed_proteins[best_model_name]\n",
......@@ -598,6 +714,15 @@
" plt.title('Model Confidence', fontsize=20, pad=20)\n",
" return plt\n",
"\n",
"# Show the structure coloured by chain if the multimer model has been used.\n",
"if model_type == ModelType.MULTIMER:\n",
" multichain_view = py3Dmol.view(width=800, height=600)\n",
" multichain_view.addModelsAsFrames(to_visualize_pdb)\n",
" multichain_style = {'cartoon': {'colorscheme': 'chain'}}\n",
" multichain_view.setStyle({'model': -1}, multichain_style)\n",
" multichain_view.zoomTo()\n",
" multichain_view.show()\n",
"\n",
"# Color the structure by per-residue pLDDT\n",
"color_map = {i: bands[2] for i, bands in enumerate(PLDDT_BANDS)}\n",
"view = py3Dmol.view(width=800, height=600)\n",
......@@ -643,6 +768,15 @@
" pae, max_pae = list(pae_outputs.values())[0]\n",
" plt.imshow(pae, vmin=0., vmax=max_pae, cmap='Greens_r')\n",
" plt.colorbar(fraction=0.046, pad=0.04)\n",
"\n",
" # Display lines at chain boundaries.\n",
" best_unrelaxed_prot = unrelaxed_proteins[best_model_name]\n",
" total_num_res = best_unrelaxed_prot.residue_index.shape[-1]\n",
" chain_ids = best_unrelaxed_prot.chain_index\n",
" for chain_boundary in np.nonzero(chain_ids[:-1] - chain_ids[1:]):\n",
" if chain_boundary.size:\n",
" plt.plot([0, total_num_res], [chain_boundary, chain_boundary], color='red')\n",
" plt.plot([chain_boundary, chain_boundary], [0, total_num_res], color='red')\n",
" plt.title('Predicted Aligned Error')\n",
" plt.xlabel('Scored residue')\n",
" plt.ylabel('Aligned residue')\n",
......@@ -680,7 +814,7 @@
"source": [
"### Interpreting the prediction\n",
"\n",
"Please see the [AlphaFold methods paper](https://www.nature.com/articles/s41586-021-03819-2) and the [AlphaFold predictions of the human proteome paper](https://www.nature.com/articles/s41586-021-03828-1), as well as [DeepMind's FAQ](https://alphafold.ebi.ac.uk/faq) on how to interpret AlphaFold/OpenFold predictions."
"Please see the [AlphaFold methods paper](https://www.nature.com/articles/s41586-021-03819-2) and the [AlphaFold predictions of the human proteome paper](https://www.nature.com/articles/s41586-021-03828-1), as well as [DeepMind's FAQ](https://alphafold.ebi.ac.uk/faq) on how to interpret AlphaFold/OpenFold predictions. More information about the predictions of the AlphaFold Multimer model can be found in the [Alphafold Multimer paper](https://www.biorxiv.org/content/10.1101/2022.03.11.484043v3.full.pdf)."
]
},
{
......@@ -718,7 +852,7 @@
" * Colab CPU runtime is not supported. Try changing your runtime via _Runtime_ > _Change runtime type_ > _Hardware accelerator_ > _GPU_.\n",
" * The type of GPU allocated to your Colab varies. See the [Colab FAQ](https://research.google.com/colaboratory/faq.html) for more details.\n",
" * If you receive “Cannot connect to GPU backend”, you can try again later to see if Colab allocates you a GPU.\n",
" * [Colab Pro](https://colab.research.google.com/signup) offers priority access to GPUs. \n",
" * [Colab Pro](https://colab.research.google.com/signup) offers priority access to GPUs.\n",
"* Does this tool install anything on my computer?\n",
" * No, everything happens in the cloud on Google Colab.\n",
" * At the end of the Colab execution a zip-archive with the obtained prediction will be automatically downloaded to your computer.\n",
......@@ -766,13 +900,10 @@
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "OpenFold.ipynb",
"provenance": [],
"gpuType": "T4",
"include_colab_link": true
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
......@@ -780,7 +911,8 @@
},
"language_info": {
"name": "python"
}
},
"accelerator": "GPU"
},
"nbformat": 4,
"nbformat_minor": 0
......
......@@ -3,15 +3,15 @@ channels:
- conda-forge
- bioconda
dependencies:
- conda-forge::openmm=7.5.1
- conda-forge::pdbfixer
- openmm=7.7
- pdbfixer
- ml-collections
- PyYAML==5.4.1
- requests
- typing-extensions
- bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04
- pip:
- biopython==1.79
- dm-tree==0.1.6
- ml-collections==0.1.0
- PyYAML==5.4.1
- requests==2.26.0
- typing-extensions==3.10.0.2
......@@ -21,14 +21,11 @@ import dataclasses
from multiprocessing import cpu_count
import tempfile
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
import subprocess
import numpy as np
import torch
import pickle
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
from openfold.data.templates import get_custom_template_features, empty_template_feats
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.data.tools.utils import to_date
from openfold.np import residue_constants, protein
FeatureDict = MutableMapping[str, np.ndarray]
......@@ -704,10 +701,10 @@ class DataPipeline:
def _parse_msa_data(
self,
alignment_dir: str,
alignment_index: Optional[Any] = None,
alignment_index: Optional[Any] = None
) -> Mapping[str, Any]:
msa_data = {}
if(alignment_index is not None):
if alignment_index is not None:
fp = open(os.path.join(alignment_dir, alignment_index["db"]), "rb")
def read_msa(start, size):
......@@ -718,14 +715,14 @@ class DataPipeline:
for (name, start, size) in alignment_index["files"]:
filename, ext = os.path.splitext(name)
if(ext == ".a3m"):
if ext == ".a3m":
msa = parsers.parse_a3m(
read_msa(start, size)
)
# The "hmm_output" exception is a crude way to exclude
# multimer template hits.
# Multimer "uniprot_hits" processed separately.
elif(ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]):
elif ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]:
msa = parsers.parse_stockholm(read_msa(start, size))
else:
continue
......@@ -734,13 +731,22 @@ class DataPipeline:
fp.close()
else:
# Now will split the following steps into multiple processes
current_directory = os.path.dirname(os.path.abspath(__file__))
cmd = f"{current_directory}/tools/parse_msa_files.py"
msa_data_path = subprocess.run(['python',cmd, f"--alignment_dir={alignment_dir}"],capture_output=True, text=True)
msa_data_path = msa_data_path.stdout.lstrip().rstrip()
msa_data = pickle.load((open(msa_data_path,'rb')))
os.remove(msa_data_path)
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
filename, ext = os.path.splitext(f)
if ext == ".a3m":
with open(path, "r") as fp:
msa = parsers.parse_a3m(fp.read())
elif ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]:
with open(path, "r") as fp:
msa = parsers.parse_stockholm(
fp.read()
)
else:
continue
msa_data[f] = msa
return msa_data
......
......@@ -101,8 +101,8 @@ def empty_template_feats(n_res):
"template_all_atom_positions": np.zeros(
(0, n_res, residue_constants.atom_type_num, 3), np.float32
),
"template_domain_names": np.array([''.encode()], dtype=np.object),
"template_sequence": np.array([''.encode()], dtype=np.object),
"template_domain_names": np.array([''.encode()], dtype=object),
"template_sequence": np.array([''.encode()], dtype=object),
"template_sum_probs": np.zeros((0, 1), dtype=np.float32),
}
......
......@@ -90,15 +90,15 @@ def get_optimal_transform(
def get_least_asym_entity_or_longest_length(batch, input_asym_id):
"""
First check how many subunit(s) one sequence has, if there is no tie, e.g. AABBB then select
one of the A as anchor
First check how many subunit(s) one sequence has. Select the subunit that is less
common, e.g. if the protein was AABBB then select one of the A as anchor
If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
then choose one of the corresponding subunits as anchor
Args:
batch: in this funtion batch is the full ground truth features
input_asym_id: A list of aym_ids that are in the cropped input features
batch: in this function batch is the full ground truth features
input_asym_id: A list of asym_ids that are in the cropped input features
Return:
anchor_gt_asym_id: Tensor(int) selected ground truth asym_id
......@@ -126,7 +126,7 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
min_asym_count = min(entity_asym_count.values())
least_asym_entities = [entity for entity, count in entity_asym_count.items() if count == min_asym_count]
# If multiple entities have the least asym_id count, return those with the shortest length
# If multiple entities have the least asym_id count, return those with the longest length
if len(least_asym_entities) > 1:
max_length = max([entity_length[entity] for entity in least_asym_entities])
least_asym_entities = [entity for entity in least_asym_entities if entity_length[entity] == max_length]
......
......@@ -123,7 +123,7 @@ def parse_fasta(data):
][1:]
tags, seqs = lines[::2], lines[1::2]
tags = [t.split()[0] for t in tags]
tags = [re.split('\W| \|', t)[0] for t in tags]
return tags, seqs
......
......@@ -63,10 +63,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
with open(tmp_fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
local_alignment_dir = os.path.join(
alignment_dir,
os.path.join(alignment_dir, tag),
)
local_alignment_dir = os.path.join(alignment_dir, tag)
if args.use_precomputed_alignments is None:
logger.info(f"Generating alignments for {tag}...")
......
......@@ -113,10 +113,10 @@ else:
setup(
name='openfold',
version='1.0.1',
version='2.0.0',
description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2',
author='Gustaf Ahdritz & DeepMind',
author_email='gahdritz@gmail.com',
author='OpenFold Team',
author_email='jennifer.wei@omsf.io',
license='Apache License, Version 2.0',
url='https://github.com/aqlaboratory/openfold',
packages=find_packages(exclude=["tests", "scripts"]),
......
import ml_collections as mlc
consts = mlc.ConfigDict(
monomer_consts = mlc.ConfigDict(
{
"model": "model_1_ptm", # monomer:model_1_ptm, multimer: model_1_multimer_v3
"is_multimer": False, # monomer: False, multimer: True
"chunk_size": 4,
"batch_size": 2,
"n_res": 22,
"n_seq": 13,
"n_templ": 3,
"n_extra": 17,
"n_heads_extra_msa": 8,
"eps": 5e-4,
# For compatibility with DeepMind's pretrained weights, it's easiest for
# everyone if these take their real values.
"c_m": 256,
"c_z": 128,
"c_s": 384,
"c_t": 64,
"c_e": 64,
"msa_logits": 23, # monomer: 23, multimer: 22
"template_mmcif_dir": None # Set for test_multimer_datamodule
}
)
multimer_consts = mlc.ConfigDict(
{
"model": "model_1_multimer_v3", # monomer:model_1_ptm, multimer: model_1_multimer_v3
"is_multimer": True, # monomer: False, multimer: True
......@@ -24,6 +49,8 @@ consts = mlc.ConfigDict(
}
)
consts = monomer_consts
config = mlc.ConfigDict(
{
"data": {
......
......@@ -244,9 +244,6 @@ class TestDeepSpeedKernel(unittest.TestCase):
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
inds = np.random.randint(0, 21, (n_res,))
batch["target_feat"] = np.eye(22)[inds]
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
template_feats = {
k: v for k, v in batch.items() if k.startswith("template_")
......@@ -309,7 +306,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
batch["residx_atom37_to_atom14"] = batch[
"residx_atom37_to_atom14"
].long()
batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], 21).to(torch.float32)
batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], consts.msa_logits - 1).to(torch.float32)
batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
batch.update(
data_transforms.atom37_to_torsion_angles("template_")(batch)
......
......@@ -21,7 +21,6 @@ from openfold.utils.multi_chain_permutation import (pad_features, get_least_asym
merge_labels)
@unittest.skip("Tests need to be fixed post-refactor")
class TestPermutation(unittest.TestCase):
def setUp(self):
"""
......@@ -65,10 +64,16 @@ class TestPermutation(unittest.TestCase):
'seq_length': torch.tensor([57])
}
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch, batch['asym_id'])
self.assertIn(int(anchor_gt_asym), [1, 2])
self.assertNotIn(int(anchor_gt_asym), [3, 4, 5])
self.assertIn(int(anchor_pred_asym), [1, 2])
self.assertNotIn(int(anchor_pred_asym), [3, 4, 5])
anchor_gt_asym = int(anchor_gt_asym)
anchor_pred_asym = {int(i) for i in anchor_pred_asym}
expected_anchors = {1, 2}
expected_non_anchors = {3, 4, 5}
self.assertIn(anchor_gt_asym, expected_anchors)
self.assertNotIn(anchor_gt_asym, expected_non_anchors)
# Check that predicted anchors are within expected anchor set
self.assertEqual(anchor_pred_asym, expected_anchors & anchor_pred_asym)
self.assertEqual(set(), anchor_pred_asym & expected_non_anchors)
def test_2_permutation_pentamer(self):
batch = {
......@@ -114,6 +119,7 @@ class TestPermutation(unittest.TestCase):
self.assertIn(aligns, possible_outcome)
self.assertNotIn(aligns, wrong_outcome)
@unittest.skip("Test needs to be fixed post-refactor")
def test_3_merge_labels(self):
nres_pad = 325 - 57 # suppose the cropping size is 325
batch = {
......
......@@ -235,6 +235,7 @@ class OpenFoldWrapper(pl.LightningModule):
lr_scheduler = AlphaFoldLRScheduler(
optimizer,
last_epoch=self.last_lr_step
)
return {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment