Commit 892adf48 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix notebook

parent 4a6c9bf1
...@@ -428,7 +428,13 @@ ...@@ -428,7 +428,13 @@
"#@markdown to your computer.\n", "#@markdown to your computer.\n",
"\n", "\n",
"# --- Run the model ---\n", "# --- Run the model ---\n",
"model_names = ['model_1', 'model_2', 'model_3', 'model_4', 'model_5', 'model_1_ptm']\n", "model_names = [\n",
" 'finetuning_2.pt', \n",
" 'finetuning_3.pt', \n",
" 'finetuning_4.pt', \n",
" 'finetuning_5.pt', \n",
" 'finetuning_ptm_2.pt'\n",
"]\n",
"\n", "\n",
"def _placeholder_template_feats(num_templates_, num_res_):\n", "def _placeholder_template_feats(num_templates_, num_res_):\n",
" return {\n", " return {\n",
...@@ -447,7 +453,7 @@ ...@@ -447,7 +453,7 @@
"unrelaxed_proteins = {}\n", "unrelaxed_proteins = {}\n",
"\n", "\n",
"with tqdm.notebook.tqdm(total=len(model_names) + 1, bar_format=TQDM_BAR_FORMAT) as pbar:\n", "with tqdm.notebook.tqdm(total=len(model_names) + 1, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" for model_name in model_names:\n", " for i, model_name in enumerate(model_names):\n",
" pbar.set_description(f'Running {model_name}')\n", " pbar.set_description(f'Running {model_name}')\n",
" num_templates = 1 # dummy number --- is ignored\n", " num_templates = 1 # dummy number --- is ignored\n",
" num_res = len(sequence)\n", " num_res = len(sequence)\n",
...@@ -457,21 +463,23 @@ ...@@ -457,21 +463,23 @@
" feature_dict.update(data_pipeline.make_msa_features(msas, deletion_matrices=deletion_matrices))\n", " feature_dict.update(data_pipeline.make_msa_features(msas, deletion_matrices=deletion_matrices))\n",
" feature_dict.update(_placeholder_template_feats(num_templates, num_res))\n", " feature_dict.update(_placeholder_template_feats(num_templates, num_res))\n",
"\n", "\n",
" cfg = config.model_config(model_name)\n", " if(weight_set == \"AlphaFold\")\n",
" config_preset = f\"model_{i}\"\n",
" else:\n",
" config_preset = \"model_1\"\n",
"\n",
" cfg = config.model_config(config_preset)\n",
" openfold_model = model.AlphaFold(cfg)\n", " openfold_model = model.AlphaFold(cfg)\n",
" openfold_model = openfold_model.eval()\n", " openfold_model = openfold_model.eval()\n",
" if(weight_set == \"AlphaFold\"):\n", " if(weight_set == \"AlphaFold\"):\n",
" params_name = os.path.join(ALPHAFOLD_PARAMS_DIR, f\"params_{model_name}.npz\")\n", " params_name = os.path.join(\n",
" import_jax_weights_(openfold_model, params_name, version=model_name)\n", " ALPHAFOLD_PARAMS_DIR, f\"params_{config_preset}.npz\"\n",
" )\n",
" import_jax_weights_(openfold_model, params_name, version=config_preset)\n",
" elif(weight_set == \"OpenFold\"):\n", " elif(weight_set == \"OpenFold\"):\n",
" model_name_spl = model_name.split(\"_\")\n",
" if(model_name_spl[-1] == \"ptm\"):\n",
" of_model_name = \"finetuning_ptm_2.pt\"\n",
" else:\n",
" of_model_name = f\"finetuning_{model_name_spl[-1]}.pt\"\n",
" params_name = os.path.join(\n", " params_name = os.path.join(\n",
" OPENFOLD_PARAMS_DIR,\n", " OPENFOLD_PARAMS_DIR,\n",
" of_model_name\n", " model_name,\n",
" )\n", " )\n",
" d = torch.load(params_name)\n", " d = torch.load(params_name)\n",
" openfold_model.load_state_dict(d)\n", " openfold_model.load_state_dict(d)\n",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment