Commit 502d7a19, authored by Hamish Tomlinson, committed by Copybara-Service
Browse files

Add cell execution order error checking.

PiperOrigin-RevId: 508360288
Change-Id: I1d976ac08808e4356f89c29428668c1093816056
parent 665ebc30
......@@ -14,7 +14,7 @@
"""Helper methods for the AlphaFold Colab notebook."""
import json
from typing import Any, Mapping, Optional, Sequence, Tuple
from typing import AbstractSet, Any, Mapping, Optional, Sequence
from alphafold.common import residue_constants
from alphafold.data import parsers
......@@ -166,3 +166,24 @@ def get_pae_json(pae: np.ndarray, max_pae: float) -> str:
'max_predicted_aligned_error': max_pae
}]
return json.dumps(formatted_output, indent=None, separators=(',', ':'))
def check_cell_execution_order(
    cells_ran: AbstractSet[int], cell_number: int) -> None:
  """Check that the cell execution order is correct.

  Every cell numbered from 1 up to (but not including) `cell_number` must
  already be present in `cells_ran`; otherwise an error listing the missing
  cells is raised.

  Args:
    cells_ran: Set of cell numbers that have been executed.
    cell_number: The number of the cell that this check is called in.

  Raises:
    ValueError: If any of the cells 1, ..., cell_number - 1 has not been
      executed. The message lists the missing cell numbers in ascending order.
  """
  previous_cells = set(range(1, cell_number))
  # `set.difference` accepts any iterable, so this also works when the caller
  # tracks executed cells in a list or tuple rather than a set.
  cells_not_ran = previous_cells.difference(cells_ran)
  if cells_not_ran:
    cells_not_ran_str = ', '.join(str(x) for x in sorted(cells_not_ran))
    raise ValueError(
        f'You did not execute the cells: {cells_not_ran_str}. Your Colab '
        'runtime may have died during execution. Please restart the runtime '
        'and run from the first cell!')
......@@ -191,6 +191,18 @@ class NotebookUtilsTest(parameterized.TestCase):
pae_json, '[{"predicted_aligned_error":[[0.0,13.1],[20.1,0.0]],'
'"max_predicted_aligned_error":31.75}]')
def test_check_cell_execution_order_correct(self):
  """The check passes silently when all preceding cells have been run."""
  already_executed = {1, 2}
  current_cell = 3
  notebook_utils.check_cell_execution_order(already_executed, current_cell)
@parameterized.named_parameters(
    ('One missing', 4, {1, 2}, '3'),
    ('Two missing', 5, {1, 2}, '3, 4'),
)
def test_check_cell_execution_order_missing(
    self, cell_num, cells_ran, cells_missing):
  """A ValueError naming the skipped cells is raised when cells are missing."""
  # The error message must contain the comma-separated missing cell numbers.
  expected_message_regex = f'.+{cells_missing}'
  with self.assertRaisesRegex(ValueError, expected_message_regex):
    notebook_utils.check_cell_execution_order(cells_ran, cell_num)
if __name__ == '__main__':
absltest.main()
......@@ -60,7 +60,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "woIxeCPygt7K"
},
"outputs": [],
......@@ -126,14 +125,15 @@
" pbar.update(1)\n",
"except subprocess.CalledProcessError:\n",
" print(captured)\n",
" raise"
" raise\n",
"\n",
"executed_cells = set([1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "VzJ5iMjTtoZw"
},
"outputs": [],
......@@ -195,7 +195,9 @@
"# Make sure everything we need is on the path.\n",
"import sys\n",
"sys.path.append('/opt/conda/lib/python3.8/site-packages')\n",
"sys.path.append('/content/alphafold')"
"sys.path.append('/content/alphafold')\n",
"\n",
"executed_cells.add(2)"
]
},
{
......@@ -215,7 +217,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "rowN0bVYLe9n"
},
"outputs": [],
......@@ -227,6 +228,9 @@
"#@markdown * If you enter multiple sequences, the multimer model will be used.\n",
"\n",
"from alphafold.notebooks import notebook_utils\n",
"# Track cell execution to ensure correct order.\n",
"notebook_utils.check_cell_execution_order(executed_cells, 3)\n",
"\n",
"import enum\n",
"\n",
"@enum.unique\n",
......@@ -316,14 +320,15 @@
" print('WARNING: The accuracy of the system has not been fully validated '\n",
" 'above 3000 residues, and you may experience long running times or '\n",
" f'run out of memory. Total sequence length is {total_sequence_length} '\n",
" 'residues.')\n"
" 'residues.')\n",
"\n",
"executed_cells.add(3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "2tTeTTsLKPjB"
},
"outputs": [],
......@@ -336,6 +341,9 @@
"#@markdown you’ll see how well each residue is covered by similar\n",
"#@markdown sequences in the MSA.\n",
"\n",
"# Track cell execution to ensure correct order\n",
"notebook_utils.check_cell_execution_order(executed_cells, 4)\n",
"\n",
"# --- Python imports ---\n",
"import collections\n",
"import copy\n",
......@@ -531,14 +539,15 @@
" all_chain_features=all_chain_features)\n",
"\n",
" # Pad MSA to avoid zero-sized extra_msa.\n",
" np_example = pipeline_multimer.pad_msa(np_example, min_num_seq=512)"
" np_example = pipeline_multimer.pad_msa(np_example, min_num_seq=512)\n",
"\n",
"executed_cells.add(4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "XUo6foMQxwS2"
},
"outputs": [],
......@@ -568,6 +577,8 @@
"\n",
"multimer_model_max_num_recycles = 3 #@param {type:\"integer\"}\n",
"\n",
"# Track cell execution to ensure correct order\n",
"notebook_utils.check_cell_execution_order(executed_cells, 5)\n",
"\n",
"# --- Run the model ---\n",
"if model_type_to_use == ModelType.MONOMER:\n",
......@@ -775,7 +786,9 @@
"\n",
"# --- Download the predictions ---\n",
"shutil.make_archive(base_name='prediction', format='zip', root_dir=output_dir)\n",
"files.download(f'{output_dir}.zip')"
"files.download(f'{output_dir}.zip')\n",
"\n",
"executed_cells.add(5)"
]
},
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment