"docker/vscode:/vscode.git/clone" did not exist on "864fd4f7daa00a1b4ef9183a62c623d65148e1bd"
Commit c9f2ae14 authored by derekjchow's avatar derekjchow Committed by Sergio Guadarrama
Browse files

Download model in Jupyter Notebook. (#1580)

parent fb96b71a
...@@ -26,8 +26,11 @@ ...@@ -26,8 +26,11 @@
"source": [ "source": [
"import numpy as np\n", "import numpy as np\n",
"import os\n", "import os\n",
"import six.moves.urllib as urllib\n",
"import sys\n", "import sys\n",
"import tarfile\n",
"import tensorflow as tf\n", "import tensorflow as tf\n",
"import zipfile\n",
"\n", "\n",
"from collections import defaultdict\n", "from collections import defaultdict\n",
"from io import StringIO\n", "from io import StringIO\n",
...@@ -89,7 +92,9 @@ ...@@ -89,7 +92,9 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Variables" "## Variables\n",
"\n",
"See the [detection model zoo](g3doc/detection_model_zoo.md) for a list of all models to try."
] ]
}, },
{ {
...@@ -100,8 +105,13 @@ ...@@ -100,8 +105,13 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"# What model to download.\n",
"MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017'\n",
"MODEL_FILE = MODEL_NAME + '.tar.gz'\n",
"DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'\n",
"\n",
"# Path to frozen detection graph. This is the actual model that is used for the object detection.\n", "# Path to frozen detection graph. This is the actual model that is used for the object detection.\n",
"PATH_TO_CKPT = os.path.join('test_ckpt', 'ssd_inception_v2.pb')\n", "PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'\n",
"\n", "\n",
"# List of the strings that is used to add correct label for each box.\n", "# List of the strings that is used to add correct label for each box.\n",
"PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')\n", "PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')\n",
...@@ -113,13 +123,39 @@ ...@@ -113,13 +123,39 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Load a (frozen) Tensorflow model into memory." "## Download Model"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"opener = urllib.request.URLopener()\n",
"opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)\n",
"tar_file = tarfile.open(MODEL_FILE)\n",
"for file in tar_file.getmembers():\n",
" file_name = os.path.basename(file.name)\n",
" if 'frozen_inference_graph.pb' in file_name:\n",
" tar_file.extract(file, os.getcwd())"
]
},
{
"cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [
"## Load a (frozen) Tensorflow model into memory."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"detection_graph = tf.Graph()\n", "detection_graph = tf.Graph()\n",
...@@ -142,7 +178,9 @@ ...@@ -142,7 +178,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {}, "metadata": {
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"label_map = label_map_util.load_labelmap(PATH_TO_LABELS)\n", "label_map = label_map_util.load_labelmap(PATH_TO_LABELS)\n",
...@@ -201,6 +239,7 @@ ...@@ -201,6 +239,7 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"collapsed": true,
"scrolled": true "scrolled": true
}, },
"outputs": [], "outputs": [],
...@@ -237,25 +276,34 @@ ...@@ -237,25 +276,34 @@
" plt.figure(figsize=IMAGE_SIZE)\n", " plt.figure(figsize=IMAGE_SIZE)\n",
" plt.imshow(image_np)" " plt.imshow(image_np)"
] ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
} }
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 2",
"language": "python", "language": "python",
"name": "python3" "name": "python2"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {
"name": "ipython", "name": "ipython",
"version": 3 "version": 2
}, },
"file_extension": ".py", "file_extension": ".py",
"mimetype": "text/x-python", "mimetype": "text/x-python",
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython2",
"version": "3.4.3" "version": "2.7.13"
} }
}, },
"nbformat": 4, "nbformat": 4,
......
...@@ -395,7 +395,7 @@ def visualize_boxes_and_labels_on_image_array(image, ...@@ -395,7 +395,7 @@ def visualize_boxes_and_labels_on_image_array(image,
classes[i] % len(STANDARD_COLORS)] classes[i] % len(STANDARD_COLORS)]
# Draw all boxes onto image. # Draw all boxes onto image.
for box, color in box_to_color_map.iteritems(): for box, color in six.iteritems(box_to_color_map):
ymin, xmin, ymax, xmax = box ymin, xmin, ymax, xmax = box
if instance_masks is not None: if instance_masks is not None:
draw_mask_on_image_array( draw_mask_on_image_array(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment