Commit bbb9c9fd authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Revamp Colab

parent dc4d04c8
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
"metadata": { "metadata": {
"accelerator": "GPU", "accelerator": "GPU",
"colab": { "colab": {
"name": "OpenFold.ipynb", "name": "OpenFold (2).ipynb",
"provenance": [], "provenance": [],
"collapsed_sections": [] "collapsed_sections": []
}, },
...@@ -55,6 +55,33 @@ ...@@ -55,6 +55,33 @@
"FAQ on how to interpret AlphaFold/OpenFold predictions are [here](https://alphafold.ebi.ac.uk/faq)." "FAQ on how to interpret AlphaFold/OpenFold predictions are [here](https://alphafold.ebi.ac.uk/faq)."
] ]
}, },
{
"cell_type": "code",
"metadata": {
"id": "rowN0bVYLe9n",
"cellView": "form"
},
"source": [
"#@markdown ### Enter the amino acid sequence to fold ⬇️\n",
"sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n",
"\n",
"#@markdown ### Configure the model ⬇️\n",
"\n",
"weight_set = 'OpenFold' #@param [\"OpenFold\", \"AlphaFold\"]\n",
"relax_prediction = True #@param {type:\"boolean\"}\n",
"\n",
"# Remove all whitespaces, tabs and end lines; upper-case\n",
"sequence = sequence.translate(str.maketrans('', '', ' \\n\\t')).upper()\n",
"aatypes = set('ACDEFGHIKLMNPQRSTVWY') # 20 standard aatypes\n",
"if not set(sequence).issubset(aatypes):\n",
" raise Exception(f'Input sequence contains non-amino acid letters: {set(sequence) - aatypes}. OpenFold only supports 20 standard amino acids as inputs.')\n",
"\n",
"#@markdown After making your selections, execute this cell by pressing the\n",
"#@markdown *Play* button on the left."
],
"execution_count": null,
"outputs": []
},
{ {
"cell_type": "code", "cell_type": "code",
"metadata": { "metadata": {
...@@ -63,10 +90,9 @@ ...@@ -63,10 +90,9 @@
}, },
"source": [ "source": [
"#@title Install third-party software\n", "#@title Install third-party software\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@markdown the left.\n",
"\n", "\n",
"#@markdown Please execute this cell by pressing the _Play_ button \n",
"#@markdown on the left to download and import third-party software \n",
"#@markdown in this Colab notebook. (See the [acknowledgements](https://github.com/deepmind/alphafold/#acknowledgements) in DeepMind's README.)\n",
"\n", "\n",
"#@markdown **Note**: This installs the software on the Colab \n", "#@markdown **Note**: This installs the software on the Colab \n",
"#@markdown notebook in the cloud and not on your computer.\n", "#@markdown notebook in the cloud and not on your computer.\n",
...@@ -79,43 +105,45 @@ ...@@ -79,43 +105,45 @@
"TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'\n", "TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'\n",
"\n", "\n",
"try:\n", "try:\n",
" with tqdm.notebook.tqdm(total=100, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" with io.capture_output() as captured:\n", " with io.capture_output() as captured:\n",
" # Uninstall default Colab version of PyTorch.\n",
" # %shell pip uninstall -y torch\n",
"\n",
" %shell sudo apt install --quiet --yes hmmer\n", " %shell sudo apt install --quiet --yes hmmer\n",
" pbar.update(6)\n",
"\n", "\n",
" # Install py3dmol.\n", " # Install py3dmol.\n",
" %shell pip install py3dmol\n", " %shell pip install py3dmol\n",
" pbar.update(2)\n",
"\n", "\n",
" # Install OpenMM and pdbfixer.\n",
" %shell rm -rf /opt/conda\n", " %shell rm -rf /opt/conda\n",
" %shell wget -q -P /tmp \\\n", " %shell wget -q -P /tmp \\\n",
" https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \\\n", " https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \\\n",
" && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \\\n", " && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \\\n",
" && rm /tmp/Miniconda3-latest-Linux-x86_64.sh\n", " && rm /tmp/Miniconda3-latest-Linux-x86_64.sh\n",
" pbar.update(9)\n",
"\n", "\n",
" PATH=%env PATH\n", " PATH=%env PATH\n",
" %env PATH=/opt/conda/bin:{PATH}\n", " %env PATH=/opt/conda/bin:{PATH}\n",
" pbar.update(80)\n", "\n",
" # Install the required versions of all dependencies.\n",
" %shell conda install -y -q -c conda-forge -c bioconda \\\n",
" kalign2=2.04 \\\n",
" hhsuite=3.3.0 \\\n",
" python=3.7 \\\n",
" 2>&1 1>/dev/null\n",
" %shell pip install -q \\\n",
" ml-collections==0.1.0 \\\n",
" PyYAML==5.4.1 \\\n",
" biopython==1.79\n",
"\n", "\n",
" # Create a ramdisk to store a database chunk to make Jackhmmer run fast.\n", " # Create a ramdisk to store a database chunk to make Jackhmmer run fast.\n",
" %shell sudo mkdir -m 777 --parents /tmp/ramdisk\n", " %shell sudo mkdir -m 777 --parents /tmp/ramdisk\n",
" %shell sudo mount -t tmpfs -o size=9G ramdisk /tmp/ramdisk\n", " %shell sudo mount -t tmpfs -o size=9G ramdisk /tmp/ramdisk\n",
" pbar.update(2)\n",
"\n", "\n",
" %shell wget -q -P /content \\\n", " %shell wget -q -P /content \\\n",
" https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt\n", " https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt\n",
" pbar.update(1)\n",
"\n",
" # Install git-lfs\n",
" %shell sudo apt-get install git-lfs\n",
" %shell git lfs install\n",
"\n", "\n",
" # Install AWS CLI\n",
" %shell curl \"https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip\" -o \"awscliv2.zip\"\n",
" %shell unzip -qq awscliv2.zip\n",
" %shell sudo ./aws/install\n",
" %shell rm awscliv2.zip\n",
" %shell rm -rf ./aws\n",
"except subprocess.CalledProcessError as captured:\n", "except subprocess.CalledProcessError as captured:\n",
" print(captured)\n", " print(captured)\n",
" raise" " raise"
...@@ -130,14 +158,12 @@ ...@@ -130,14 +158,12 @@
"cellView": "form" "cellView": "form"
}, },
"source": [ "source": [
"#@title Download OpenFold\n", "#@title Install OpenFold\n",
"\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n", "#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@markdown the left.\n", "#@markdown the left.\n",
"\n", "\n",
"GIT_REPO = 'https://github.com/aqlaboratory/openfold'\n", "# Define constants\n",
"\n", "GIT_REPO='https://github.com/aqlaboratory/openfold'\n",
"OPENFOLD_PARAM_SOURCE_URL = \"https://huggingface.co/nz/OpenFold\"\n",
"ALPHAFOLD_PARAM_SOURCE_URL = 'https://storage.googleapis.com/alphafold/alphafold_params_2022-01-19.tar'\n", "ALPHAFOLD_PARAM_SOURCE_URL = 'https://storage.googleapis.com/alphafold/alphafold_params_2022-01-19.tar'\n",
"OPENFOLD_PARAMS_DIR = './openfold/openfold/resources/openfold_params'\n", "OPENFOLD_PARAMS_DIR = './openfold/openfold/resources/openfold_params'\n",
"ALPHAFOLD_PARAMS_DIR = './openfold/openfold/resources/params'\n", "ALPHAFOLD_PARAMS_DIR = './openfold/openfold/resources/params'\n",
...@@ -146,38 +172,40 @@ ...@@ -146,38 +172,40 @@
")\n", ")\n",
"\n", "\n",
"try:\n", "try:\n",
" with tqdm.notebook.tqdm(total=100, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" with io.capture_output() as captured:\n", " with io.capture_output() as captured:\n",
" # Run setup.py to install only Openfold.\n",
" %shell rm -rf openfold\n", " %shell rm -rf openfold\n",
" %shell git clone {GIT_REPO} openfold\n", " %shell git clone \"{GIT_REPO}\" openfold 2>&1 1> /dev/null\n",
" pbar.update(8)\n",
" # Install the required versions of all dependencies.\n",
" %shell conda env update -n base --file openfold/environment.yml\n",
" \n",
" %shell mkdir -p /content/openfold/openfold/resources\n", " %shell mkdir -p /content/openfold/openfold/resources\n",
" %shell cp -f /content/stereo_chemical_props.txt /content/openfold/openfold/resources\n", " %shell cp -f /content/stereo_chemical_props.txt /content/openfold/openfold/resources\n",
" \n", " %shell /usr/bin/python3 -m pip install -q ./openfold\n",
" # Run setup.py to install only Openfold.\n",
" %shell pip3 install --no-dependencies ./openfold\n",
" pbar.update(10)\n",
"\n", "\n",
" if(relax_prediction):\n",
" %shell conda install -y -q -c conda-forge \\\n",
" openmm=7.5.1 \\\n",
" pdbfixer=1.7\n",
" \n",
" # Apply OpenMM patch.\n", " # Apply OpenMM patch.\n",
" %shell pushd /opt/conda/lib/python3.7/site-packages/ && \\\n", " %shell pushd /opt/conda/lib/python3.7/site-packages/ && \\\n",
" patch -p0 < /content/openfold/lib/openmm.patch && \\\n", " patch -p0 < /content/openfold/lib/openmm.patch && \\\n",
" popd\n", " popd\n",
"\n", "\n",
" if(weight_set == 'AlphaFold'):\n",
" %shell mkdir --parents \"{ALPHAFOLD_PARAMS_DIR}\"\n", " %shell mkdir --parents \"{ALPHAFOLD_PARAMS_DIR}\"\n",
" %shell wget -O {ALPHAFOLD_PARAMS_PATH} {ALPHAFOLD_PARAM_SOURCE_URL}\n", " %shell wget -O {ALPHAFOLD_PARAMS_PATH} {ALPHAFOLD_PARAM_SOURCE_URL}\n",
" pbar.update(27)\n",
"\n",
" %shell tar --extract --verbose --file=\"{ALPHAFOLD_PARAMS_PATH}\" \\\n", " %shell tar --extract --verbose --file=\"{ALPHAFOLD_PARAMS_PATH}\" \\\n",
" --directory=\"{ALPHAFOLD_PARAMS_DIR}\" --preserve-permissions\n", " --directory=\"{ALPHAFOLD_PARAMS_DIR}\" --preserve-permissions\n",
" %shell rm \"{ALPHAFOLD_PARAMS_PATH}\"\n", " %shell rm \"{ALPHAFOLD_PARAMS_PATH}\"\n",
"\n", " elif(weight_set == 'OpenFold'):\n",
" %shell mkdir --parents \"{OPENFOLD_PARAMS_DIR}\"\n", " %shell mkdir --parents \"{OPENFOLD_PARAMS_DIR}\"\n",
" %shell git clone \"{OPENFOLD_PARAM_SOURCE_URL}\" \"{OPENFOLD_PARAMS_DIR}\"\n", " %shell aws s3 cp \\\n",
" pbar.update(55)\n", " --no-sign-request \\\n",
"except subprocess.CalledProcessError:\n", " --region us-east-1 \\\n",
" s3://openfold/openfold_params \"{OPENFOLD_PARAMS_DIR}\" \\\n",
" --recursive\n",
" else:\n",
" raise ValueError(\"Invalid weight set\")\n",
"except subprocess.CalledProcessError as captured:\n",
" print(captured)\n", " print(captured)\n",
" raise" " raise"
], ],
...@@ -191,8 +219,23 @@ ...@@ -191,8 +219,23 @@
"#@markdown Please execute this cell by pressing the *Play* button on \n", "#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@markdown the left.\n", "#@markdown the left.\n",
"\n", "\n",
"import unittest.mock\n",
"import sys\n", "import sys\n",
"\n",
"sys.path.insert(0, '/usr/local/lib/python3.7/site-packages/')\n",
"sys.path.append('/opt/conda/lib/python3.7/site-packages')\n", "sys.path.append('/opt/conda/lib/python3.7/site-packages')\n",
"\n",
"# Allows us to skip installing these packages\n",
"unnecessary_modules = [\n",
" \"dllogger\",\n",
" \"pytorch_lightning\",\n",
" \"pytorch_lightning.utilities\",\n",
" \"pytorch_lightning.callbacks.early_stopping\",\n",
" \"pytorch_lightning.utilities.seed\",\n",
"]\n",
"for unnecessary_module in unnecessary_modules:\n",
" sys.modules[unnecessary_module] = unittest.mock.MagicMock()\n",
"\n",
"import os\n", "import os\n",
"\n", "\n",
"from urllib import request\n", "from urllib import request\n",
...@@ -224,8 +267,9 @@ ...@@ -224,8 +267,9 @@
"from openfold.data.tools import jackhmmer\n", "from openfold.data.tools import jackhmmer\n",
"from openfold.model import model\n", "from openfold.model import model\n",
"from openfold.np import protein\n", "from openfold.np import protein\n",
"from openfold.np.relax import relax\n", "if(relax_prediction):\n",
"from openfold.np.relax import utils\n", " from openfold.np.relax import relax\n",
" from openfold.np.relax import utils\n",
"from openfold.utils.import_weights import import_jax_weights_\n", "from openfold.utils.import_weights import import_jax_weights_\n",
"from openfold.utils.tensor_utils import tensor_tree_map\n", "from openfold.utils.tensor_utils import tensor_tree_map\n",
"\n", "\n",
...@@ -234,8 +278,8 @@ ...@@ -234,8 +278,8 @@
"from ipywidgets import Output" "from ipywidgets import Output"
], ],
"metadata": { "metadata": {
"cellView": "form", "id": "_FpxxMo-mvcP",
"id": "_FpxxMo-mvcP" "cellView": "form"
}, },
"execution_count": null, "execution_count": null,
"outputs": [] "outputs": []
...@@ -248,41 +292,9 @@ ...@@ -248,41 +292,9 @@
"source": [ "source": [
"## Making a prediction\n", "## Making a prediction\n",
"\n", "\n",
"Please paste the sequence of your protein in the text box below, then run the remaining cells via _Runtime_ > _Run after_. You can also run the cells individually by pressing the _Play_ button on the left.\n",
"\n",
"Note that the search against databases and the actual prediction can take some time, from minutes to hours, depending on the length of the protein and what type of GPU you are allocated by Colab (see FAQ below)." "Note that the search against databases and the actual prediction can take some time, from minutes to hours, depending on the length of the protein and what type of GPU you are allocated by Colab (see FAQ below)."
] ]
}, },
{
"cell_type": "code",
"metadata": {
"id": "rowN0bVYLe9n",
"cellView": "form"
},
"source": [
"#@markdown ### Enter the amino acid sequence to fold ⬇️\n",
"sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n",
"\n",
"MIN_SEQUENCE_LENGTH = 16\n",
"MAX_SEQUENCE_LENGTH = 2500\n",
"\n",
"#@markdown ### Choose between OpenFold and AlphaFold model parameters ⬇️\n",
"\n",
"weight_set = 'OpenFold' #@param [\"OpenFold\", \"AlphaFold\"]\n",
"\n",
"# Remove all whitespaces, tabs and end lines; upper-case\n",
"sequence = sequence.translate(str.maketrans('', '', ' \\n\\t')).upper()\n",
"aatypes = set('ACDEFGHIKLMNPQRSTVWY') # 20 standard aatypes\n",
"if not set(sequence).issubset(aatypes):\n",
" raise Exception(f'Input sequence contains non-amino acid letters: {set(sequence) - aatypes}. AlphaFold only supports 20 standard amino acids as inputs.')\n",
"if len(sequence) < MIN_SEQUENCE_LENGTH:\n",
" raise Exception(f'Input sequence is too short: {len(sequence)} amino acids, while the minimum is {MIN_SEQUENCE_LENGTH}')\n",
"if len(sequence) > MAX_SEQUENCE_LENGTH:\n",
" raise Exception(f'Input sequence is too long: {len(sequence)} amino acids, while the maximum is {MAX_SEQUENCE_LENGTH}. Please use the full AlphaFold system for long sequences.')"
],
"execution_count": null,
"outputs": []
},
{ {
"cell_type": "code", "cell_type": "code",
"metadata": { "metadata": {
...@@ -298,12 +310,6 @@ ...@@ -298,12 +310,6 @@
"#@markdown you’ll see how well each residue is covered by similar \n", "#@markdown you’ll see how well each residue is covered by similar \n",
"#@markdown sequences in the MSA.\n", "#@markdown sequences in the MSA.\n",
"\n", "\n",
"# Color bands for visualizing plddt\n",
"PLDDT_BANDS = [(0, 50, '#FF7D45'),\n",
" (50, 70, '#FFDB13'),\n",
" (70, 90, '#65CBF3'),\n",
" (90, 100, '#0053D6')]\n",
"\n",
"# --- Find the closest source ---\n", "# --- Find the closest source ---\n",
"test_url_pattern = 'https://storage.googleapis.com/alphafold-colab{:s}/latest/uniref90_2021_03.fasta.1'\n", "test_url_pattern = 'https://storage.googleapis.com/alphafold-colab{:s}/latest/uniref90_2021_03.fasta.1'\n",
"ex = futures.ThreadPoolExecutor(3)\n", "ex = futures.ThreadPoolExecutor(3)\n",
...@@ -427,22 +433,30 @@ ...@@ -427,22 +433,30 @@
"#@markdown the obtained prediction will be automatically downloaded \n", "#@markdown the obtained prediction will be automatically downloaded \n",
"#@markdown to your computer.\n", "#@markdown to your computer.\n",
"\n", "\n",
"# Color bands for visualizing plddt\n",
"PLDDT_BANDS = [\n",
" (0, 50, '#FF7D45'),\n",
" (50, 70, '#FFDB13'),\n",
" (70, 90, '#65CBF3'),\n",
" (90, 100, '#0053D6')\n",
"]\n",
"\n",
"# --- Run the model ---\n", "# --- Run the model ---\n",
"model_names = [\n", "model_names = [ \n",
" 'finetuning_2.pt', \n",
" 'finetuning_3.pt', \n", " 'finetuning_3.pt', \n",
" 'finetuning_4.pt', \n", " 'finetuning_4.pt', \n",
" 'finetuning_5.pt', \n", " 'finetuning_5.pt', \n",
" 'finetuning_ptm_2.pt'\n", " 'finetuning_ptm_2.pt',\n",
" 'finetuning_no_templ_ptm_1.pt'\n",
"]\n", "]\n",
"\n", "\n",
"def _placeholder_template_feats(num_templates_, num_res_):\n", "def _placeholder_template_feats(num_templates_, num_res_):\n",
" return {\n", " return {\n",
" 'template_aatype': torch.zeros(num_templates_, num_res_, 22).long(),\n", " 'template_aatype': np.zeros((num_templates_, num_res_, 22), dtype=np.int64),\n",
" 'template_all_atom_positions': torch.zeros(num_templates_, num_res_, 37, 3),\n", " 'template_all_atom_positions': np.zeros((num_templates_, num_res_, 37, 3), dtype=np.float32),\n",
" 'template_all_atom_mask': torch.zeros(num_templates_, num_res_, 37),\n", " 'template_all_atom_mask': np.zeros((num_templates_, num_res_, 37), dtype=np.float32),\n",
" 'template_domain_names': torch.zeros(num_templates_),\n", " 'template_domain_names': np.zeros((num_templates_,), dtype=np.float32),\n",
" 'template_sum_probs': torch.zeros(num_templates_, 1),\n", " 'template_sum_probs': np.zeros((num_templates_, 1), dtype=np.float32),\n",
" }\n", " }\n",
"\n", "\n",
"output_dir = 'prediction'\n", "output_dir = 'prediction'\n",
...@@ -453,11 +467,11 @@ ...@@ -453,11 +467,11 @@
"unrelaxed_proteins = {}\n", "unrelaxed_proteins = {}\n",
"\n", "\n",
"with tqdm.notebook.tqdm(total=len(model_names) + 1, bar_format=TQDM_BAR_FORMAT) as pbar:\n", "with tqdm.notebook.tqdm(total=len(model_names) + 1, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" for i, model_name in enumerate(model_names):\n", " for i, model_name in list(enumerate(model_names)):\n",
" pbar.set_description(f'Running {model_name}')\n", " pbar.set_description(f'Running {model_name}')\n",
" num_templates = 1 # dummy number --- is ignored\n", " num_templates = 1 # dummy number --- is ignored\n",
" num_res = len(sequence)\n", " num_res = len(sequence)\n",
"\n", " \n",
" feature_dict = {}\n", " feature_dict = {}\n",
" feature_dict.update(data_pipeline.make_sequence_features(sequence, 'test', num_res))\n", " feature_dict.update(data_pipeline.make_sequence_features(sequence, 'test', num_res))\n",
" feature_dict.update(data_pipeline.make_msa_features(msas, deletion_matrices=deletion_matrices))\n", " feature_dict.update(data_pipeline.make_msa_features(msas, deletion_matrices=deletion_matrices))\n",
...@@ -466,6 +480,9 @@ ...@@ -466,6 +480,9 @@
" if(weight_set == \"AlphaFold\"):\n", " if(weight_set == \"AlphaFold\"):\n",
" config_preset = f\"model_{i}\"\n", " config_preset = f\"model_{i}\"\n",
" else:\n", " else:\n",
" if(\"_no_templ_\" in model_name):\n",
" config_preset = \"model_3\"\n",
" else:\n",
" config_preset = \"model_1\"\n", " config_preset = \"model_1\"\n",
" if(\"_ptm_\" in model_name):\n", " if(\"_ptm_\" in model_name):\n",
" config_preset += \"_ptm\"\n", " config_preset += \"_ptm\"\n",
...@@ -536,7 +553,12 @@ ...@@ -536,7 +553,12 @@
" del prediction_result\n", " del prediction_result\n",
" pbar.update(n=1)\n", " pbar.update(n=1)\n",
"\n", "\n",
" # Find the best model according to the mean pLDDT.\n",
" best_model_name = max(plddts.keys(), key=lambda x: plddts[x].mean())\n",
" best_pdb = protein.to_pdb(unrelaxed_proteins[best_model_name])\n",
"\n",
" # --- AMBER relax the best model ---\n", " # --- AMBER relax the best model ---\n",
" if(relax_prediction):\n",
" pbar.set_description(f'AMBER relaxation')\n", " pbar.set_description(f'AMBER relaxation')\n",
" amber_relaxer = relax.AmberRelaxation(\n", " amber_relaxer = relax.AmberRelaxation(\n",
" max_iterations=0,\n", " max_iterations=0,\n",
...@@ -544,12 +566,19 @@ ...@@ -544,12 +566,19 @@
" stiffness=10.0,\n", " stiffness=10.0,\n",
" exclude_residues=[],\n", " exclude_residues=[],\n",
" max_outer_iterations=20,\n", " max_outer_iterations=20,\n",
" use_gpu=True,\n", " use_gpu=False,\n",
" )\n", " )\n",
" # Find the best model according to the mean pLDDT.\n",
" best_model_name = max(plddts.keys(), key=lambda x: plddts[x].mean())\n",
" relaxed_pdb, _, _ = amber_relaxer.process(\n", " relaxed_pdb, _, _ = amber_relaxer.process(\n",
" prot=unrelaxed_proteins[best_model_name])\n", " prot=unrelaxed_proteins[best_model_name]\n",
" )\n",
"\n",
" # Write out the prediction\n",
" pred_output_path = os.path.join(output_dir, 'selected_prediction.pdb')\n",
" with open(pred_output_path, 'w') as f:\n",
" f.write(relaxed_pdb)\n",
"\n",
" best_pdb = relaxed_pdb\n",
"\n",
" pbar.update(n=1) # Finished AMBER relax.\n", " pbar.update(n=1) # Finished AMBER relax.\n",
"\n", "\n",
"# Construct multiclass b-factors to indicate confidence bands\n", "# Construct multiclass b-factors to indicate confidence bands\n",
...@@ -561,14 +590,7 @@ ...@@ -561,14 +590,7 @@
" banded_b_factors.append(idx)\n", " banded_b_factors.append(idx)\n",
" break\n", " break\n",
"banded_b_factors = np.array(banded_b_factors)[:, None] * final_atom_mask\n", "banded_b_factors = np.array(banded_b_factors)[:, None] * final_atom_mask\n",
"to_visualize_pdb = utils.overwrite_b_factors(relaxed_pdb, banded_b_factors)\n", "to_visualize_pdb = utils.overwrite_b_factors(best_pdb, banded_b_factors)\n",
"\n",
"\n",
"# Write out the prediction\n",
"pred_output_path = os.path.join(output_dir, 'selected_prediction.pdb')\n",
"with open(pred_output_path, 'w') as f:\n",
" f.write(relaxed_pdb)\n",
"\n",
"\n", "\n",
"# --- Visualise the prediction & confidence ---\n", "# --- Visualise the prediction & confidence ---\n",
"show_sidechains = True\n", "show_sidechains = True\n",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment