"launch/dynemo-run/src/opt.rs" did not exist on "2f700421cb168e9550693a8d0a76b688bf0c7967"
Commit 94a3b18e authored by jnwei's avatar jnwei Committed by Jennifer Wei
Browse files

Adding multimer support to OpenFold notebook

parent 81ca652e
......@@ -3,8 +3,7 @@
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
......@@ -50,46 +49,76 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {
"cellView": "form",
"id": "rowN0bVYLe9n"
"id": "rowN0bVYLe9n",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputs": [],
"outputId": "8a4c9c7d-a555-460f-de39-d6aff2272d36"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Length of input sequence : 716\n"
]
}
],
"source": [
"#@markdown ### Enter the amino acid sequence to fold ⬇️\n",
"sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n",
"#@markdown For multiple sequences, separate sequences with a colon `:`\n",
"input_sequence = 'MVDATRVPMDERFRTLKKKLEEGMVFTEYEQIPKKKANGIFSTAALPENAERSRIREVVPYEENRVELIPTKENNTGYINASHIKVVVGGAEWHYIATQGPLPHTCHDFWQMVWEQGVNVIAMVTAEEEGGRTKSHRYWPKLGSKHSSATYGKFKVTTKFRTDSVCYATTGLKVKHLLSGQERTVWHLQYTDWPDHGCPEDVQGFLSYLEEIQSVRRHTNSMLEGTKNRHPPIVVHCSAGVGRTGVLILSELMIYCLEHNEKVEVPMMLRLLREQRMFMIQTIAQYKFVYQVLIQFLQNSRLI:MVDATRVPMDERFRTLKKKLEEGMVFTEYEQIPKKKANGIFSTAALPENAERSRIREVVPYEENRVELIPTKENNTGYINASHIKVVVGGAEWHYIATQGPLPHTCHDFWQMVWEQGVNVIAMVTAEEEGGRTKSHRYWPKLGSKHSSATYGKFKVTTKFRTDSVCYATTGLKVKHLLSGQERTVWHLQYTDWPDHGCPEDVQGFLSYLEEIQSVRRHTNSMLEGTKNRHPPIVVHCSAGVGRTGVLILSELMIYCLEHNEKVEVPMMLRLLREQRMFMIQTIAQYKFVYQVLIQFLQNSRLI:GHMAEPQRHTMLCMCCKCEARIELVVESSADDLRAFQQLFLNTLSFVCPWCASQQ:GHMAEPQRHTMLCMCCKCEARIELVVESSADDLRAFQQLFLNTLSFVCPWCASQQ' #@param {type:\"string\"}\n",
"\n",
"#@markdown ### Configure the model ⬇️\n",
"\n",
"weight_set = 'OpenFold' #@param [\"OpenFold\", \"AlphaFold\"]\n",
"weight_set = 'AlphaFold' #@param [\"OpenFold\", \"AlphaFold\"]\n",
"model_mode = 'multimer' #@param [\"monomer\", \"multimer\"]\n",
"relax_prediction = True #@param {type:\"boolean\"}\n",
"\n",
"\n",
"# Remove all whitespaces, tabs and end lines; upper-case\n",
"sequence = sequence.translate(str.maketrans('', '', ' \\n\\t')).upper()\n",
"input_sequence = input_sequence.translate(str.maketrans('', '', ' \\n\\t')).upper()\n",
"aatypes = set('ACDEFGHIKLMNPQRSTVWY') # 20 standard aatypes\n",
"if not set(sequence).issubset(aatypes):\n",
" raise Exception(f'Input sequence contains non-amino acid letters: {set(sequence) - aatypes}. OpenFold only supports 20 standard amino acids as inputs.')\n",
"\n",
"allowed_chars = aatypes.union({':'})\n",
"if not set(input_sequence).issubset(allowed_chars):\n",
" raise Exception(f'Input sequence contains non-amino acid letters: {set(input_sequence) - allowed_chars}. OpenFold only supports 20 standard amino acids as inputs.')\n",
"\n",
"# A ':'-separated input describes a complex; this notebook only runs complexes\n",
"# with the AlphaFold weight set.\n",
"if ':' in input_sequence and weight_set != 'AlphaFold':\n",
"  raise ValueError('Input sequence is a multimer, must select Alphafold weight set')\n",
"\n",
"# In monomer mode only a single chain is featurized downstream, so a multi-chain\n",
"# input would be folded incompletely without any warning. Fail early instead.\n",
"if ':' in input_sequence and model_mode != 'multimer':\n",
"  raise ValueError('Input contains multiple sequences, must select the multimer model mode')\n",
"\n",
"# The multimer model names are only paired with a matching config preset when the\n",
"# AlphaFold weight set is selected, so reject the multimer/OpenFold combination.\n",
"if model_mode == 'multimer' and weight_set != 'AlphaFold':\n",
"  raise ValueError('Multimer model mode requires the AlphaFold weight set')\n",
"\n",
"import enum\n",
"@enum.unique\n",
"class ModelType(enum.Enum):\n",
" MONOMER = 0\n",
" MULTIMER = 1\n",
"\n",
"model_type_dict = {\n",
" 'monomer': ModelType.MONOMER,\n",
" 'multimer': ModelType.MULTIMER,\n",
"}\n",
"\n",
"model_type = model_type_dict[model_mode]\n",
"print(f'Length of input sequence : {len(input_sequence.replace(\":\", \"\"))}')\n",
"#@markdown After making your selections, execute this cell by pressing the\n",
"#@markdown *Play* button on the left."
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {
"cellView": "form",
"id": "woIxeCPygt7K"
},
"outputs": [],
"source": [
"#@title Install third-party software\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@markdown Please execute this cell by pressing the *Play* button on\n",
"#@markdown the left.\n",
"\n",
"\n",
"#@markdown **Note**: This installs the software on the Colab \n",
"#@markdown **Note**: This installs the software on the Colab\n",
"#@markdown notebook in the cloud and not on your computer.\n",
"\n",
"import os, time\n",
......@@ -103,10 +132,8 @@
"os.system(\"wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh\")\n",
"os.system(\"bash Mambaforge-Linux-x86_64.sh -bfp /usr/local\")\n",
"os.system(\"mamba config --set auto_update_conda false\")\n",
"os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=7.7.0 python={python_version} pdbfixer\")\n",
"\n",
"\n",
"os.system(\"pip install -q \\\"torch<2\\\" biopython ml_collections py3Dmol modelcif\")\n",
"os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=7.7.0 python={python_version} pdbfixer biopython=1.79\")\n",
"os.system(\"pip install -q torch ml_collections py3Dmol modelcif\")\n",
"\n",
"try:\n",
" with io.capture_output() as captured:\n",
......@@ -119,12 +146,12 @@
" %shell wget -q -P /content \\\n",
" https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt\n",
"\n",
" %shell mkdir -p /content/openfold/openfold/resourcees\n",
" \n",
" commit = \"099769d2ecfd01a8baa8d950030df454a042c910\"\n",
" %shell mkdir -p /content/openfold/openfold/resources\n",
"\n",
" commit = \"e2e19f16676b1a409f9ba3a6f69b11ee7f5887c2\"\n",
" os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n",
" \n",
" %shell cp -f /content/stereo_chemical_props.txt /usr/local/lib/python3.10/site-packages/openfold/resources/\n",
"\n",
" os.system(f\"cp -f -p /content/stereo_chemical_props.txt /usr/local/lib/python{python_version}/site-packages/openfold/resources/\")\n",
"\n",
"except subprocess.CalledProcessError as captured:\n",
" print(captured)"
......@@ -132,20 +159,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {
"cellView": "form",
"id": "VzJ5iMjTtoZw"
},
"outputs": [],
"source": [
"#@title Download model weights \n",
"#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@title Download model weights\n",
"#@markdown Please execute this cell by pressing the *Play* button on\n",
"#@markdown the left.\n",
"\n",
"# Define constants\n",
"GIT_REPO='https://github.com/aqlaboratory/openfold'\n",
"ALPHAFOLD_PARAM_SOURCE_URL = 'https://storage.googleapis.com/alphafold/alphafold_params_2022-01-19.tar'\n",
"ALPHAFOLD_PARAM_SOURCE_URL = 'https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar'\n",
"OPENFOLD_PARAMS_DIR = './openfold/openfold/resources/openfold_params'\n",
"ALPHAFOLD_PARAMS_DIR = './openfold/openfold/resources/params'\n",
"ALPHAFOLD_PARAMS_PATH = os.path.join(\n",
......@@ -182,19 +208,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {
"cellView": "form",
"id": "_FpxxMo-mvcP"
},
"outputs": [],
"source": [
"#@title Import Python packages\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@markdown Please execute this cell by pressing the *Play* button on\n",
"#@markdown the left.\n",
"\n",
"import unittest.mock\n",
"import sys\n",
"from typing import Dict, Sequence\n",
"\n",
"sys.path.insert(0, f'/usr/local/lib/python{python_version}/dist-packages/')\n",
"sys.path.insert(0, f'/usr/local/lib/python{python_version}/site-packages/')\n",
......@@ -234,22 +260,12 @@
" return \"UTF-8\"\n",
"locale.getpreferredencoding = getpreferredencoding\n",
"\n",
"# A filthy hack to avoid slow Linear layer initialization\n",
"import openfold.model.primitives\n",
"\n",
"def __default_linear_init__(self, *args, **kwargs):\n",
" return torch.nn.Linear.__init__(\n",
" self, \n",
" *args[:2], \n",
" **{k:v for k,v in kwargs.items() if k == \"bias\"}\n",
" )\n",
"\n",
"openfold.model.primitives.Linear.__init__ = __default_linear_init__\n",
"\n",
"from openfold import config\n",
"from openfold.data import feature_pipeline\n",
"from openfold.data import parsers\n",
"from openfold.data import data_pipeline\n",
"from openfold.data import msa_pairing\n",
"from openfold.data import feature_processing_multimer\n",
"from openfold.data.tools import jackhmmer\n",
"from openfold.model import model\n",
"from openfold.np import protein\n",
......@@ -276,22 +292,16 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "2tTeTTsLKPjB"
},
"outputs": [],
"source": [
"#@title Search against genetic databases\n",
"\n",
"#@markdown Once this cell has been executed, you will see\n",
"#@markdown statistics about the multiple sequence alignment \n",
"#@markdown (MSA) that will be used by OpenFold. In particular, \n",
"#@markdown you’ll see how well each residue is covered by similar \n",
"#@markdown statistics about the multiple sequence alignment\n",
"#@markdown (MSA) that will be used by OpenFold. In particular,\n",
"#@markdown you’ll see how well each residue is covered by similar\n",
"#@markdown sequences in the MSA.\n",
"\n",
"# --- Find the closest source ---\n",
"# --- Find the closest source ---\n",
"test_url_pattern = 'https://storage.googleapis.com/alphafold-colab{:s}/latest/uniref90_2021_03.fasta.1'\n",
"ex = futures.ThreadPoolExecutor(3)\n",
"def fetch(source):\n",
......@@ -304,114 +314,156 @@
" ex.shutdown()\n",
" break\n",
"\n",
"# --- Search against genetic databases ---\n",
"with open('target.fasta', 'wt') as f:\n",
" f.write(f'>query\\n{sequence}')\n",
"\n",
"# Run the search against chunks of genetic databases (since the genetic\n",
"# databases don't fit in Colab ramdisk).\n",
"\n",
"jackhmmer_binary_path = '/usr/bin/jackhmmer'\n",
"dbs = []\n",
"\n",
"num_jackhmmer_chunks = {'uniref90': 59, 'smallbfd': 17, 'mgnify': 71}\n",
"total_jackhmmer_chunks = sum(num_jackhmmer_chunks.values())\n",
"# --- Parse multiple sequences, if there are any ---\n",
"def split_multiple_sequences(sequence):\n",
" seqs = sequence.split(':')\n",
" sorted_seqs = sorted(seqs, key=lambda s: len(s))\n",
"\n",
" # TODO: Handle the homomer case when writing fasta sequences\n",
" fasta_path_tuples = []\n",
" for idx, seq in enumerate(set(sorted_seqs)):\n",
" fasta_path = f'target_{idx+1}.fasta'\n",
" with open(fasta_path, 'wt') as f:\n",
" f.write(f'>query\\n{seq}\\n')\n",
" fasta_path_tuples.append((seq, fasta_path))\n",
" fasta_path_by_seq = dict(fasta_path_tuples)\n",
"\n",
" return sorted_seqs, fasta_path_by_seq\n",
"\n",
"sequences, fasta_path_by_sequence = split_multiple_sequences(input_sequence)\n",
"db_results_by_sequence = {seq: {} for seq in fasta_path_by_sequence.keys()}\n",
"\n",
"DB_ROOT_PATH = f'https://storage.googleapis.com/alphafold-colab{source}/latest/'\n",
"db_configs = {}\n",
"# smallbfd is searched separately from uniref90 (which has its own entry below),\n",
"# so it must point at its own FASTA shards; z_value and the chunk count belong to\n",
"# that database, not to UniRef90.\n",
"db_configs['smallbfd'] = {\n",
"  'database_path': f'{DB_ROOT_PATH}bfd-first_non_consensus_sequences.fasta',\n",
"  'z_value': 65984053,\n",
"  'num_jackhmmer_chunks': 17,\n",
"}\n",
"db_configs['mgnify'] = {\n",
" 'database_path': f'{DB_ROOT_PATH}mgy_clusters_2022_05.fasta',\n",
" 'z_value': 304820129,\n",
" 'num_jackhmmer_chunks': 120,\n",
"}\n",
"db_configs['uniref90'] = {\n",
" 'database_path': f'{DB_ROOT_PATH}uniref90_2022_01.fasta',\n",
" 'z_value': 144113457,\n",
" 'num_jackhmmer_chunks': 62,\n",
"}\n",
"\n",
"# Search UniProt and construct the all_seq features only for heteromers, not homomers.\n",
"if model_type == ModelType.MULTIMER and len(set(sequences)) > 1:\n",
" db_configs['uniprot'] = {\n",
" 'database_path': f'{DB_ROOT_PATH}uniprot_2021_04.fasta',\n",
" 'z_value': 225013025 + 565928,\n",
" 'num_jackhmmer_chunks': 101,\n",
" }\n",
"\n",
"total_jackhmmer_chunks = sum([d['num_jackhmmer_chunks'] for d in db_configs.values()])\n",
"with tqdm.notebook.tqdm(total=total_jackhmmer_chunks, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" def jackhmmer_chunk_callback(i):\n",
" pbar.update(n=1)\n",
"\n",
" pbar.set_description('Searching uniref90')\n",
" jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(\n",
" for db_name, db_config in db_configs.items():\n",
" pbar.set_description(f'Searching {db_name}')\n",
" jackhmmer_runner = jackhmmer.Jackhmmer(\n",
" binary_path=jackhmmer_binary_path,\n",
" database_path=f'https://storage.googleapis.com/alphafold-colab{source}/latest/uniref90_2021_03.fasta',\n",
" database_path=db_config['database_path'],\n",
" get_tblout=True,\n",
" num_streamed_chunks=num_jackhmmer_chunks['uniref90'],\n",
" num_streamed_chunks=db_config['num_jackhmmer_chunks'],\n",
" streaming_callback=jackhmmer_chunk_callback,\n",
" z_value=135301051)\n",
" dbs.append(('uniref90', jackhmmer_uniref90_runner.query('target.fasta')))\n",
" z_value=db_config['z_value'])\n",
"\n",
" pbar.set_description('Searching smallbfd')\n",
" jackhmmer_smallbfd_runner = jackhmmer.Jackhmmer(\n",
" binary_path=jackhmmer_binary_path,\n",
" database_path=f'https://storage.googleapis.com/alphafold-colab{source}/latest/bfd-first_non_consensus_sequences.fasta',\n",
" get_tblout=True,\n",
" num_streamed_chunks=num_jackhmmer_chunks['smallbfd'],\n",
" streaming_callback=jackhmmer_chunk_callback,\n",
" z_value=65984053)\n",
" dbs.append(('smallbfd', jackhmmer_smallbfd_runner.query('target.fasta')))\n",
"\n",
" pbar.set_description('Searching mgnify')\n",
" jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(\n",
" binary_path=jackhmmer_binary_path,\n",
" database_path=f'https://storage.googleapis.com/alphafold-colab{source}/latest/mgy_clusters_2019_05.fasta',\n",
" get_tblout=True,\n",
" num_streamed_chunks=num_jackhmmer_chunks['mgnify'],\n",
" streaming_callback=jackhmmer_chunk_callback,\n",
" z_value=304820129)\n",
" dbs.append(('mgnify', jackhmmer_mgnify_runner.query('target.fasta')))\n",
" db_results = jackhmmer_runner.query_multiple(fasta_path_by_sequence.values())\n",
" for seq, result in zip(fasta_path_by_sequence.keys(), db_results):\n",
" db_results_by_sequence[seq][db_name] = result\n",
"\n",
"\n",
"# --- Extract the MSAs and visualize ---\n",
"# Extract the MSAs from the Stockholm files.\n",
"# NB: deduplication happens later in data_pipeline.make_msa_features.\n",
"\n",
"mgnify_max_hits = 501\n",
"MAX_HITS_BY_DB = {\n",
" 'uniref90': 10000,\n",
" 'smallbfd': 5000,\n",
" 'mgnify': 501,\n",
" 'uniprot': 50000,\n",
"}\n",
"\n",
"msas_by_seq_by_db = {seq: {} for seq in sequences}\n",
"full_msa_by_seq = {seq: [] for seq in sequences}\n",
"\n",
"msas = []\n",
"deletion_matrices = []\n",
"full_msa = []\n",
"for db_name, db_results in dbs:\n",
"for seq, sequence_result in db_results_by_sequence.items():\n",
" print(f'parsing_results_for_sequence {seq}')\n",
" for db_name, db_results in sequence_result.items():\n",
" unsorted_results = []\n",
" for i, result in enumerate(db_results):\n",
" msa, deletion_matrix, target_names = parsers.parse_stockholm(result['sto'])\n",
" msa_obj = parsers.parse_stockholm(result['sto'])\n",
" e_values_dict = parsers.parse_e_values_from_tblout(result['tbl'])\n",
" target_names = msa_obj.descriptions\n",
" e_values = [e_values_dict[t.split('/')[0]] for t in target_names]\n",
" zipped_results = zip(msa, deletion_matrix, target_names, e_values)\n",
" zipped_results = zip(msa_obj.sequences, msa_obj.deletion_matrix, target_names, e_values)\n",
" if i != 0:\n",
" # Only take query from the first chunk\n",
" zipped_results = [x for x in zipped_results if x[2] != 'query']\n",
" unsorted_results.extend(zipped_results)\n",
" sorted_by_evalue = sorted(unsorted_results, key=lambda x: x[3])\n",
" db_msas, db_deletion_matrices, _, _ = zip(*sorted_by_evalue)\n",
" msas, del_matrix, targets, _ = zip(*sorted_by_evalue)\n",
" db_msas = parsers.Msa(msas, del_matrix, targets)\n",
" if db_msas:\n",
" if db_name == 'mgnify':\n",
" db_msas = db_msas[:mgnify_max_hits]\n",
" db_deletion_matrices = db_deletion_matrices[:mgnify_max_hits]\n",
" full_msa.extend(db_msas)\n",
" msas.append(db_msas)\n",
" deletion_matrices.append(db_deletion_matrices)\n",
" msa_size = len(set(db_msas))\n",
" if db_name in MAX_HITS_BY_DB:\n",
" db_msas.truncate(MAX_HITS_BY_DB[db_name])\n",
" msas_by_seq_by_db[seq][db_name] = db_msas\n",
" full_msa_by_seq[seq].extend(db_msas.sequences)\n",
" msa_size = len(set(db_msas.sequences))\n",
" print(f'{msa_size} Sequences Found in {db_name}')\n",
"\n",
"deduped_full_msa = list(dict.fromkeys(full_msa))\n",
"total_msa_size = len(deduped_full_msa)\n",
"print(f'\\n{total_msa_size} Sequences Found in Total\\n')\n",
"\n",
"aa_map = {restype: i for i, restype in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ-')}\n",
"msa_arr = np.array([[aa_map[aa] for aa in seq] for seq in deduped_full_msa])\n",
"num_alignments, num_res = msa_arr.shape\n",
"\n",
"fig = plt.figure(figsize=(12, 3))\n",
"max_num_alignments = 0\n",
"\n",
"for seq_idx, seq in enumerate(set(sequences)):\n",
" full_msas = full_msa_by_seq[seq]\n",
" deduped_full_msa = list(dict.fromkeys(full_msas))\n",
" total_msa_size = len(deduped_full_msa)\n",
" print(f'\\n{total_msa_size} Sequences Found in Total\\n')\n",
"\n",
" aa_map = {restype: i for i, restype in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ-')}\n",
" msa_arr = np.array([[aa_map[aa] for aa in seq] for seq in deduped_full_msa])\n",
" num_alignments, num_res = msa_arr.shape\n",
" plt.plot(np.sum(msa_arr != aa_map['-'], axis=0), label=f'Chain {seq_idx}')\n",
" max_num_alignments = max(num_alignments, max_num_alignments)\n",
"\n",
"\n",
"plt.title('Per-Residue Count of Non-Gap Amino Acids in the MSA')\n",
"plt.plot(np.sum(msa_arr != aa_map['-'], axis=0), color='black')\n",
"plt.ylabel('Non-Gap Count')\n",
"plt.yticks(range(0, num_alignments + 1, max(1, int(num_alignments / 3))))\n",
"plt.yticks(range(0, max_num_alignments + 1, max(1, int(max_num_alignments / 3))))\n",
"plt.legend()\n",
"plt.show()"
]
],
"metadata": {
"id": "o7BqQN_gfYtq"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "XUo6foMQxwS2"
},
"outputs": [],
"source": [
"#@title Run OpenFold and download prediction\n",
"\n",
"#@markdown Once this cell has been executed, a zip-archive with \n",
"#@markdown the obtained prediction will be automatically downloaded \n",
"#@markdown Once this cell has been executed, a zip-archive with\n",
"#@markdown the obtained prediction will be automatically downloaded\n",
"#@markdown to your computer.\n",
"\n",
"# Color bands for visualizing plddt\n",
......@@ -423,13 +475,22 @@
"]\n",
"\n",
"# --- Run the model ---\n",
"model_names = [ \n",
" 'finetuning_3.pt', \n",
" 'finetuning_4.pt', \n",
" 'finetuning_5.pt', \n",
"if model_type == ModelType.MONOMER:\n",
" model_names = [\n",
" 'finetuning_3.pt',\n",
" 'finetuning_4.pt',\n",
" 'finetuning_5.pt',\n",
" 'finetuning_ptm_2.pt',\n",
" 'finetuning_no_templ_ptm_1.pt'\n",
"]\n",
" ]\n",
"elif model_type == ModelType.MULTIMER:\n",
" model_names = [\n",
" 'model_1_multimer_v3',\n",
" 'model_2_multimer_v3',\n",
" 'model_3_multimer_v3',\n",
" 'model_4_multimer_v3',\n",
" 'model_5_multimer_v3',\n",
" ]\n",
"\n",
"def _placeholder_template_feats(num_templates_, num_res_):\n",
" return {\n",
......@@ -440,26 +501,72 @@
" 'template_sum_probs': np.zeros((num_templates_, 1), dtype=np.float32),\n",
" }\n",
"\n",
"\n",
"def make_features(\n",
" sequences: Sequence[str],\n",
" msas_by_seq_by_db: Dict[str, Dict[str, parsers.Msa]],\n",
" model_type: ModelType):\n",
" num_templates = 1 # Placeholder for generating fake template features\n",
" feature_dict = {}\n",
"\n",
" for idx, seq in enumerate(sequences, start=1):\n",
" _chain_id = f'chain_{idx}'\n",
" num_res = len(seq)\n",
"\n",
" feats = data_pipeline.make_sequence_features(seq, _chain_id, num_res)\n",
" msas_without_uniprot = [msas_by_seq_by_db[seq][db] for db in db_configs.keys() if db != 'uniprot']\n",
" msa_feats = data_pipeline.make_msa_features(msas_without_uniprot)\n",
" feats.update(msa_feats)\n",
" feats.update(_placeholder_template_feats(num_templates, num_res))\n",
"\n",
" if model_type == ModelType.MONOMER:\n",
" feature_dict[seq] = feats\n",
" if model_type == ModelType.MULTIMER:\n",
" # Perform extra pair processing steps for heteromers\n",
" if len(set(sequences)) > 1:\n",
" uniprot_msa = msas_by_seq_by_db[seq]['uniprot']\n",
" uniprot_msa_features = data_pipeline.make_msa_features([uniprot_msa])\n",
" valid_feat_names = msa_pairing.MSA_FEATURES + (\n",
" 'msa_species_identifiers',\n",
" )\n",
" pair_feats = {\n",
" f'{k}_all_seq': v for k, v in uniprot_msa_features.items()\n",
" if k in valid_feat_names\n",
" }\n",
" feats.update(pair_feats)\n",
"\n",
" feats = data_pipeline.convert_monomer_features(feats, _chain_id)\n",
" feature_dict[_chain_id] = feats\n",
"\n",
" if model_type == ModelType.MONOMER:\n",
" np_example = feature_dict[sequences[0]]\n",
" elif model_type == ModelType.MULTIMER:\n",
" all_chain_feats = data_pipeline.add_assembly_features(feature_dict)\n",
" features = feature_processing_multimer.pair_and_merge(all_chain_features=all_chain_feats)\n",
" np_example = data_pipeline.pad_msa(features, 512)\n",
"\n",
" return np_example\n",
"\n",
"\n",
"output_dir = 'prediction'\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"\n",
"plddts = {}\n",
"pae_outputs = {}\n",
"weighted_ptms = {}\n",
"unrelaxed_proteins = {}\n",
"\n",
"with tqdm.notebook.tqdm(total=len(model_names) + 1, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" for i, model_name in list(enumerate(model_names)):\n",
"with tqdm.notebook.tqdm(total=len(model_names), bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" for i, model_name in enumerate(model_names, start = 1):\n",
" pbar.set_description(f'Running {model_name}')\n",
" num_templates = 1 # dummy number --- is ignored\n",
" num_res = len(sequence)\n",
" \n",
" feature_dict = {}\n",
" feature_dict.update(data_pipeline.make_sequence_features(sequence, 'test', num_res))\n",
" feature_dict.update(data_pipeline.make_msa_features(msas, deletion_matrices=deletion_matrices))\n",
" feature_dict.update(_placeholder_template_feats(num_templates, num_res))\n",
"\n",
" feature_dict = make_features(sequences, msas_by_seq_by_db, model_type)\n",
"\n",
" if(weight_set == \"AlphaFold\"):\n",
" if model_type == ModelType.MONOMER:\n",
" config_preset = f\"model_{i}\"\n",
" elif model_type == ModelType.MULTIMER:\n",
" config_preset = f'model_{i}_multimer_v3'\n",
" else:\n",
" if(\"_no_templ_\" in model_name):\n",
" config_preset = \"model_3\"\n",
......@@ -469,6 +576,11 @@
" config_preset += \"_ptm\"\n",
"\n",
" cfg = config.model_config(config_preset)\n",
"\n",
" # Force the model to only use 3 recycling updates\n",
" cfg.data.common.max_recycling_iters = 3\n",
" cfg.model.recycle_early_stop_tolerance = -1\n",
"\n",
" openfold_model = model.AlphaFold(cfg)\n",
" openfold_model = openfold_model.eval()\n",
" if(weight_set == \"AlphaFold\"):\n",
......@@ -490,7 +602,9 @@
"\n",
" pipeline = feature_pipeline.FeaturePipeline(cfg.data)\n",
" processed_feature_dict = pipeline.process_features(\n",
" feature_dict, mode='predict'\n",
" feature_dict,\n",
" mode='predict',\n",
" is_multimer = (model_type == ModelType.MULTIMER),\n",
" )\n",
"\n",
" processed_feature_dict = tensor_tree_map(\n",
......@@ -510,6 +624,7 @@
"\n",
" mean_plddt = prediction_result['plddt'].mean()\n",
"\n",
" if model_type == ModelType.MONOMER:\n",
" if 'predicted_aligned_error' in prediction_result:\n",
" pae_outputs[model_name] = (\n",
" prediction_result['predicted_aligned_error'],\n",
......@@ -519,12 +634,22 @@
" # Get the pLDDT confidence metrics. Do not put pTM models here as they\n",
" # should never get selected.\n",
" plddts[model_name] = prediction_result['plddt']\n",
" elif model_type == ModelType.MULTIMER:\n",
" # Multimer models are sorted by pTM+ipTM.\n",
" plddts[model_name] = prediction_result['plddt']\n",
" pae_outputs[model_name] = (prediction_result['predicted_aligned_error'],\n",
" prediction_result['max_predicted_aligned_error'])\n",
"\n",
" weighted_ptms[model_name] = prediction_result['weighted_ptm_score']\n",
"\n",
" # Set the b-factors to the per-residue plddt.\n",
" final_atom_mask = prediction_result['final_atom_mask']\n",
" b_factors = prediction_result['plddt'][:, None] * final_atom_mask\n",
" unrelaxed_protein = protein.from_prediction(\n",
" processed_feature_dict, prediction_result, b_factors=b_factors\n",
" processed_feature_dict,\n",
" prediction_result,\n",
" remove_leading_feature_dimension=False,\n",
" b_factors=b_factors,\n",
" )\n",
" unrelaxed_proteins[model_name] = unrelaxed_protein\n",
"\n",
......@@ -535,7 +660,10 @@
" pbar.update(n=1)\n",
"\n",
" # Find the best model according to the mean pLDDT.\n",
" if model_type == ModelType.MONOMER:\n",
" best_model_name = max(plddts.keys(), key=lambda x: plddts[x].mean())\n",
" elif model_type == ModelType.MULTIMER:\n",
" best_model_name = max(weighted_ptms.keys(), key=lambda x: weighted_ptms[x])\n",
" best_pdb = protein.to_pdb(unrelaxed_proteins[best_model_name])\n",
"\n",
" # --- AMBER relax the best model ---\n",
......@@ -547,7 +675,7 @@
" stiffness=10.0,\n",
" exclude_residues=[],\n",
" max_outer_iterations=20,\n",
" use_gpu=False,\n",
" use_gpu=True,\n",
" )\n",
" relaxed_pdb, _, _ = amber_relaxer.process(\n",
" prot=unrelaxed_proteins[best_model_name]\n",
......@@ -598,6 +726,15 @@
" plt.title('Model Confidence', fontsize=20, pad=20)\n",
" return plt\n",
"\n",
"# Show the structure coloured by chain if the multimer model has been used.\n",
"if model_type == ModelType.MULTIMER:\n",
" multichain_view = py3Dmol.view(width=800, height=600)\n",
" multichain_view.addModelsAsFrames(to_visualize_pdb)\n",
" multichain_style = {'cartoon': {'colorscheme': 'chain'}}\n",
" multichain_view.setStyle({'model': -1}, multichain_style)\n",
" multichain_view.zoomTo()\n",
" multichain_view.show()\n",
"\n",
"# Color the structure by per-residue pLDDT\n",
"color_map = {i: bands[2] for i, bands in enumerate(PLDDT_BANDS)}\n",
"view = py3Dmol.view(width=800, height=600)\n",
......@@ -643,6 +780,15 @@
" pae, max_pae = list(pae_outputs.values())[0]\n",
" plt.imshow(pae, vmin=0., vmax=max_pae, cmap='Greens_r')\n",
" plt.colorbar(fraction=0.046, pad=0.04)\n",
"\n",
" # Display lines at chain boundaries.\n",
" best_unrelaxed_prot = unrelaxed_proteins[best_model_name]\n",
" total_num_res = best_unrelaxed_prot.residue_index.shape[-1]\n",
" chain_ids = best_unrelaxed_prot.chain_index\n",
" for chain_boundary in np.nonzero(chain_ids[:-1] - chain_ids[1:]):\n",
" if chain_boundary.size:\n",
" plt.plot([0, total_num_res], [chain_boundary, chain_boundary], color='red')\n",
" plt.plot([chain_boundary, chain_boundary], [0, total_num_res], color='red')\n",
" plt.title('Predicted Aligned Error')\n",
" plt.xlabel('Scored residue')\n",
" plt.ylabel('Aligned residue')\n",
......@@ -680,7 +826,7 @@
"source": [
"### Interpreting the prediction\n",
"\n",
"Please see the [AlphaFold methods paper](https://www.nature.com/articles/s41586-021-03819-2) and the [AlphaFold predictions of the human proteome paper](https://www.nature.com/articles/s41586-021-03828-1), as well as [DeepMind's FAQ](https://alphafold.ebi.ac.uk/faq) on how to interpret AlphaFold/OpenFold predictions."
"Please see the [AlphaFold methods paper](https://www.nature.com/articles/s41586-021-03819-2) and the [AlphaFold predictions of the human proteome paper](https://www.nature.com/articles/s41586-021-03828-1), as well as [DeepMind's FAQ](https://alphafold.ebi.ac.uk/faq) on how to interpret AlphaFold/OpenFold predictions. More information about predictions made with the AlphaFold-Multimer model can be found in the [AlphaFold-Multimer paper](https://www.biorxiv.org/content/10.1101/2022.03.11.484043v3.full.pdf)."
]
},
{
......@@ -718,7 +864,7 @@
" * Colab CPU runtime is not supported. Try changing your runtime via _Runtime_ > _Change runtime type_ > _Hardware accelerator_ > _GPU_.\n",
" * The type of GPU allocated to your Colab varies. See the [Colab FAQ](https://research.google.com/colaboratory/faq.html) for more details.\n",
" * If you receive “Cannot connect to GPU backend”, you can try again later to see if Colab allocates you a GPU.\n",
" * [Colab Pro](https://colab.research.google.com/signup) offers priority access to GPUs. \n",
" * [Colab Pro](https://colab.research.google.com/signup) offers priority access to GPUs.\n",
"* Does this tool install anything on my computer?\n",
" * No, everything happens in the cloud on Google Colab.\n",
" * At the end of the Colab execution a zip-archive with the obtained prediction will be automatically downloaded to your computer.\n",
......@@ -766,13 +912,9 @@
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "OpenFold.ipynb",
"provenance": [],
"gpuType": "T4",
"include_colab_link": true
"gpuType": "T4"
},
"kernelspec": {
"display_name": "Python 3",
......@@ -780,7 +922,8 @@
},
"language_info": {
"name": "python"
}
},
"accelerator": "GPU"
},
"nbformat": 4,
"nbformat_minor": 0
......
......@@ -3,15 +3,15 @@ channels:
- conda-forge
- bioconda
dependencies:
- conda-forge::openmm=7.5.1
- conda-forge::pdbfixer
- openmm=7.7
- pdbfixer
- ml-collections
- PyYAML==5.4.1
- requests
- typing-extensions
- bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04
- pip:
- biopython==1.79
- dm-tree==0.1.6
- ml-collections==0.1.0
- PyYAML==5.4.1
- requests==2.26.0
- typing-extensions==3.10.0.2
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment