"vscode:/vscode.git/clone" did not exist on "94377efa785870da02a5077024f9934147546dd2"
Commit 502d7a19 authored by Hamish Tomlinson's avatar Hamish Tomlinson Committed by Copybara-Service
Browse files

Add cell execution order error checking.

PiperOrigin-RevId: 508360288
Change-Id: I1d976ac08808e4356f89c29428668c1093816056
parent 665ebc30
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
"""Helper methods for the AlphaFold Colab notebook.""" """Helper methods for the AlphaFold Colab notebook."""
import json import json
from typing import Any, Mapping, Optional, Sequence, Tuple from typing import AbstractSet, Any, Mapping, Optional, Sequence
from alphafold.common import residue_constants from alphafold.common import residue_constants
from alphafold.data import parsers from alphafold.data import parsers
...@@ -166,3 +166,24 @@ def get_pae_json(pae: np.ndarray, max_pae: float) -> str: ...@@ -166,3 +166,24 @@ def get_pae_json(pae: np.ndarray, max_pae: float) -> str:
'max_predicted_aligned_error': max_pae 'max_predicted_aligned_error': max_pae
}] }]
return json.dumps(formatted_output, indent=None, separators=(',', ':')) return json.dumps(formatted_output, indent=None, separators=(',', ':'))
def check_cell_execution_order(
cells_ran: AbstractSet[int], cell_number: int) -> None:
"""Check that the cell execution order is correct.
Args:
cells_ran: Set of cell numbers that have been executed.
cell_number: The number of the cell that this check is called in.
Raises:
If <1:cell_number> cells haven't been executed, raise error.
"""
previous_cells = set(range(1, cell_number))
cells_not_ran = previous_cells - cells_ran
if cells_not_ran != set():
cells_not_ran_str = ', '.join([str(x) for x in sorted(cells_not_ran)])
raise ValueError(
f'You did not execute the cells: {cells_not_ran_str}. Your Colab '
'runtime may have died during execution. Please restart the runtime '
'and run from the first cell!')
...@@ -191,6 +191,18 @@ class NotebookUtilsTest(parameterized.TestCase): ...@@ -191,6 +191,18 @@ class NotebookUtilsTest(parameterized.TestCase):
pae_json, '[{"predicted_aligned_error":[[0.0,13.1],[20.1,0.0]],' pae_json, '[{"predicted_aligned_error":[[0.0,13.1],[20.1,0.0]],'
'"max_predicted_aligned_error":31.75}]') '"max_predicted_aligned_error":31.75}]')
def test_check_cell_execution_order_correct(self):
notebook_utils.check_cell_execution_order({1, 2}, 3)
@parameterized.named_parameters(
('One missing', 4, {1, 2}, '3'),
('Two missing', 5, {1, 2}, '3, 4'),
)
def test_check_cell_execution_order_missing(
self, cell_num, cells_ran, cells_missing):
with self.assertRaisesRegex(ValueError, f'.+{cells_missing}'):
notebook_utils.check_cell_execution_order(cells_ran, cell_num)
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()
...@@ -60,7 +60,6 @@ ...@@ -60,7 +60,6 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"cellView": "form",
"id": "woIxeCPygt7K" "id": "woIxeCPygt7K"
}, },
"outputs": [], "outputs": [],
...@@ -126,14 +125,15 @@ ...@@ -126,14 +125,15 @@
" pbar.update(1)\n", " pbar.update(1)\n",
"except subprocess.CalledProcessError:\n", "except subprocess.CalledProcessError:\n",
" print(captured)\n", " print(captured)\n",
" raise" " raise\n",
"\n",
"executed_cells = set([1])"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"cellView": "form",
"id": "VzJ5iMjTtoZw" "id": "VzJ5iMjTtoZw"
}, },
"outputs": [], "outputs": [],
...@@ -195,7 +195,9 @@ ...@@ -195,7 +195,9 @@
"# Make sure everything we need is on the path.\n", "# Make sure everything we need is on the path.\n",
"import sys\n", "import sys\n",
"sys.path.append('/opt/conda/lib/python3.8/site-packages')\n", "sys.path.append('/opt/conda/lib/python3.8/site-packages')\n",
"sys.path.append('/content/alphafold')" "sys.path.append('/content/alphafold')\n",
"\n",
"executed_cells.add(2)"
] ]
}, },
{ {
...@@ -215,7 +217,6 @@ ...@@ -215,7 +217,6 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"cellView": "form",
"id": "rowN0bVYLe9n" "id": "rowN0bVYLe9n"
}, },
"outputs": [], "outputs": [],
...@@ -227,6 +228,9 @@ ...@@ -227,6 +228,9 @@
"#@markdown * If you enter multiple sequences, the multimer model will be used.\n", "#@markdown * If you enter multiple sequences, the multimer model will be used.\n",
"\n", "\n",
"from alphafold.notebooks import notebook_utils\n", "from alphafold.notebooks import notebook_utils\n",
"# Track cell execution to ensure correct order.\n",
"notebook_utils.check_cell_execution_order(executed_cells, 3)\n",
"\n",
"import enum\n", "import enum\n",
"\n", "\n",
"@enum.unique\n", "@enum.unique\n",
...@@ -316,14 +320,15 @@ ...@@ -316,14 +320,15 @@
" print('WARNING: The accuracy of the system has not been fully validated '\n", " print('WARNING: The accuracy of the system has not been fully validated '\n",
" 'above 3000 residues, and you may experience long running times or '\n", " 'above 3000 residues, and you may experience long running times or '\n",
" f'run out of memory. Total sequence length is {total_sequence_length} '\n", " f'run out of memory. Total sequence length is {total_sequence_length} '\n",
" 'residues.')\n" " 'residues.')\n",
"\n",
"executed_cells.add(3)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"cellView": "form",
"id": "2tTeTTsLKPjB" "id": "2tTeTTsLKPjB"
}, },
"outputs": [], "outputs": [],
...@@ -336,6 +341,9 @@ ...@@ -336,6 +341,9 @@
"#@markdown you’ll see how well each residue is covered by similar\n", "#@markdown you’ll see how well each residue is covered by similar\n",
"#@markdown sequences in the MSA.\n", "#@markdown sequences in the MSA.\n",
"\n", "\n",
"# Track cell execution to ensure correct order\n",
"notebook_utils.check_cell_execution_order(executed_cells, 4)\n",
"\n",
"# --- Python imports ---\n", "# --- Python imports ---\n",
"import collections\n", "import collections\n",
"import copy\n", "import copy\n",
...@@ -531,14 +539,15 @@ ...@@ -531,14 +539,15 @@
" all_chain_features=all_chain_features)\n", " all_chain_features=all_chain_features)\n",
"\n", "\n",
" # Pad MSA to avoid zero-sized extra_msa.\n", " # Pad MSA to avoid zero-sized extra_msa.\n",
" np_example = pipeline_multimer.pad_msa(np_example, min_num_seq=512)" " np_example = pipeline_multimer.pad_msa(np_example, min_num_seq=512)\n",
"\n",
"executed_cells.add(4)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"cellView": "form",
"id": "XUo6foMQxwS2" "id": "XUo6foMQxwS2"
}, },
"outputs": [], "outputs": [],
...@@ -568,6 +577,8 @@ ...@@ -568,6 +577,8 @@
"\n", "\n",
"multimer_model_max_num_recycles = 3 #@param {type:\"integer\"}\n", "multimer_model_max_num_recycles = 3 #@param {type:\"integer\"}\n",
"\n", "\n",
"# Track cell execution to ensure correct order\n",
"notebook_utils.check_cell_execution_order(executed_cells, 5)\n",
"\n", "\n",
"# --- Run the model ---\n", "# --- Run the model ---\n",
"if model_type_to_use == ModelType.MONOMER:\n", "if model_type_to_use == ModelType.MONOMER:\n",
...@@ -775,7 +786,9 @@ ...@@ -775,7 +786,9 @@
"\n", "\n",
"# --- Download the predictions ---\n", "# --- Download the predictions ---\n",
"shutil.make_archive(base_name='prediction', format='zip', root_dir=output_dir)\n", "shutil.make_archive(base_name='prediction', format='zip', root_dir=output_dir)\n",
"files.download(f'{output_dir}.zip')" "files.download(f'{output_dir}.zip')\n",
"\n",
"executed_cells.add(5)"
] ]
}, },
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment