"...git@developer.sourcefind.cn:OpenDAS/sparseconvnet.git" did not exist on "faa48887471d862bfed66c32850304f6e409e86c"
Commit 9c8f2356 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Update notebook

parent 7b29cac4
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
"\n", "\n",
"**Differences to AlphaFold v2.0**\n", "**Differences to AlphaFold v2.0**\n",
"\n", "\n",
"OpenFold is a trainable PyTorch reimplementation of AlphaFold 2. For the purposes of inference, it is practically identical to the original (\"practically\" because ensembling, but not recycling, is excluded from OpenFold).\n", "OpenFold is a trainable PyTorch reimplementation of AlphaFold 2. For the purposes of inference, it is practically identical to the original (\"practically\" because ensembling is excluded from OpenFold (recycling is enabled, however)).\n",
"\n", "\n",
"In this notebook, OpenFold is run with DeepMind's publicly released parameters for AlphaFold 2.\n", "In this notebook, OpenFold is run with DeepMind's publicly released parameters for AlphaFold 2.\n",
"\n", "\n",
...@@ -58,7 +58,8 @@ ...@@ -58,7 +58,8 @@
{ {
"cell_type": "code", "cell_type": "code",
"metadata": { "metadata": {
"id": "woIxeCPygt7K" "id": "woIxeCPygt7K",
"cellView": "form"
}, },
"source": [ "source": [
"#@title Install third-party software\n", "#@title Install third-party software\n",
...@@ -125,7 +126,8 @@ ...@@ -125,7 +126,8 @@
{ {
"cell_type": "code", "cell_type": "code",
"metadata": { "metadata": {
"id": "VzJ5iMjTtoZw" "id": "VzJ5iMjTtoZw",
"cellView": "form"
}, },
"source": [ "source": [
"#@title Download OpenFold\n", "#@title Download OpenFold\n",
...@@ -222,7 +224,8 @@ ...@@ -222,7 +224,8 @@
{ {
"cell_type": "code", "cell_type": "code",
"metadata": { "metadata": {
"id": "2tTeTTsLKPjB" "id": "2tTeTTsLKPjB",
"cellView": "form"
}, },
"source": [ "source": [
"#@title Search against genetic databases\n", "#@title Search against genetic databases\n",
...@@ -396,13 +399,13 @@ ...@@ -396,13 +399,13 @@
"#@markdown to your computer.\n", "#@markdown to your computer.\n",
"\n", "\n",
"# --- Run the model ---\n", "# --- Run the model ---\n",
"model_names = ['model_1', 'model_2', 'model_3', 'model_4', 'model_5', 'model_2_ptm']\n", "model_names = ['model_1', 'model_2', 'model_3', 'model_4', 'model_5', 'model_1_ptm']\n",
"\n", "\n",
"def _placeholder_template_feats(num_templates_, num_res_):\n", "def _placeholder_template_feats(num_templates_, num_res_):\n",
" return {\n", " return {\n",
" 'template_aatype': torch.zeros(num_templates_, num_res_, 22).long(),\n", " 'template_aatype': torch.zeros(num_templates_, num_res_, 22).long(),\n",
" 'template_all_atom_positions': torch.zeros(num_templates_, num_res_, 37, 3),\n", " 'template_all_atom_positions': torch.zeros(num_templates_, num_res_, 37, 3),\n",
" 'template_all_atom_masks': torch.zeros(num_templates_, num_res_, 37),\n", " 'template_all_atom_mask': torch.zeros(num_templates_, num_res_, 37),\n",
" 'template_domain_names': torch.zeros(num_templates_),\n", " 'template_domain_names': torch.zeros(num_templates_),\n",
" 'template_sum_probs': torch.zeros(num_templates_, 1),\n", " 'template_sum_probs': torch.zeros(num_templates_, 1),\n",
" }\n", " }\n",
...@@ -417,7 +420,7 @@ ...@@ -417,7 +420,7 @@
"with tqdm.notebook.tqdm(total=len(model_names) + 1, bar_format=TQDM_BAR_FORMAT) as pbar:\n", "with tqdm.notebook.tqdm(total=len(model_names) + 1, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" for model_name in model_names:\n", " for model_name in model_names:\n",
" pbar.set_description(f'Running {model_name}')\n", " pbar.set_description(f'Running {model_name}')\n",
" num_templates = 0\n", " num_templates = 1 # dummy number --- is ignored\n",
" num_res = len(sequence)\n", " num_res = len(sequence)\n",
"\n", "\n",
" feature_dict = {}\n", " feature_dict = {}\n",
...@@ -432,22 +435,14 @@ ...@@ -432,22 +435,14 @@
" import_jax_weights_(openfold_model, params_name, version=model_name)\n", " import_jax_weights_(openfold_model, params_name, version=model_name)\n",
" openfold_model = openfold_model.cuda()\n", " openfold_model = openfold_model.cuda()\n",
"\n", "\n",
" pipeline = feature_pipeline.FeaturePipeline(cfg)\n", " pipeline = feature_pipeline.FeaturePipeline(cfg.data)\n",
" processed_feature_dict = pipeline.process_features(\n", " processed_feature_dict = pipeline.process_features(\n",
" feature_dict, random_seed=42\n", " feature_dict, mode='predict'\n",
" )\n", " )\n",
"\n", "\n",
" for k, v in processed_feature_dict.items():\n", " processed_feature_dict = tensor_tree_map(\n",
" v = v.permute(*list(range(len(v.shape)))[1:], 0)\n", " lambda t: t.cuda(), processed_feature_dict\n",
" v = v.cuda()\n", " )\n",
" processed_feature_dict[k] = v\n",
"\n",
" for k,v in processed_feature_dict.items():\n",
" if(k == \"template_aatype\"):\n",
" processed_feature_dict[k] = v.long()\n",
"\n",
" elif(\"residx\" in k):\n",
" processed_feature_dict[k] = v.long()\n",
"\n", "\n",
" with torch.no_grad():\n", " with torch.no_grad():\n",
" prediction_result = openfold_model(processed_feature_dict)\n", " prediction_result = openfold_model(processed_feature_dict)\n",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment