Commit 9e1eb9fd authored by Hanzi Mao's avatar Hanzi Mao
Browse files

update notebooks to specify model_type explicitly

parent c3b8a88a
...@@ -214,19 +214,6 @@ ...@@ -214,19 +214,6 @@
"To run automatic mask generation, provide a SAM model to the `SamAutomaticMaskGenerator` class. Set the path below to the SAM checkpoint. Running on CUDA and with the default model is recommended." "To run automatic mask generation, provide a SAM model to the `SamAutomaticMaskGenerator` class. Set the path below to the SAM checkpoint. Running on CUDA and with the default model is recommended."
] ]
}, },
{
"cell_type": "code",
"execution_count": 9,
"id": "17ade22d",
"metadata": {},
"outputs": [],
"source": [
"sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
"\n",
"device = \"cuda\"\n",
"model_type = \"default\""
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 10,
...@@ -238,6 +225,11 @@ ...@@ -238,6 +225,11 @@
"sys.path.append(\"..\")\n", "sys.path.append(\"..\")\n",
"from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor\n", "from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor\n",
"\n", "\n",
"sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
"model_type = \"vit_h\"\n",
"\n",
"device = \"cuda\"\n",
"\n",
"sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n", "sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n",
"sam.to(device=device)\n", "sam.to(device=device)\n",
"\n", "\n",
...@@ -446,7 +438,7 @@ ...@@ -446,7 +438,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.10" "version": "3.8.0"
} }
}, },
"nbformat": 4, "nbformat": 4,
...@@ -192,7 +192,7 @@ ...@@ -192,7 +192,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"checkpoint = \"sam_vit_h_4b8939.pth\"\n", "checkpoint = \"sam_vit_h_4b8939.pth\"\n",
"model_type = \"default\"" "model_type = \"vit_h\""
] ]
}, },
{ {
...@@ -766,7 +766,7 @@ ...@@ -766,7 +766,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.10" "version": "3.8.0"
} }
}, },
"nbformat": 4, "nbformat": 4,
......
...@@ -229,18 +229,6 @@ ...@@ -229,18 +229,6 @@
"First, load the SAM model and predictor. Change the path below to point to the SAM checkpoint. Running on CUDA and using the default model are recommended for best results." "First, load the SAM model and predictor. Change the path below to point to the SAM checkpoint. Running on CUDA and using the default model are recommended for best results."
] ]
}, },
{
"cell_type": "code",
"execution_count": 9,
"id": "17ccff22",
"metadata": {},
"outputs": [],
"source": [
"sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
"device = \"cuda\"\n",
"model_type = \"default\""
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 10,
...@@ -252,6 +240,11 @@ ...@@ -252,6 +240,11 @@
"sys.path.append(\"..\")\n", "sys.path.append(\"..\")\n",
"from segment_anything import sam_model_registry, SamPredictor\n", "from segment_anything import sam_model_registry, SamPredictor\n",
"\n", "\n",
"sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
"model_type = \"vit_h\"\n",
"\n",
"device = \"cuda\"\n",
"\n",
"sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n", "sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n",
"sam.to(device=device)\n", "sam.to(device=device)\n",
"\n", "\n",
...@@ -1015,7 +1008,7 @@ ...@@ -1015,7 +1008,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.10" "version": "3.8.0"
} }
}, },
"nbformat": 4, "nbformat": 4,
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment