Commit c9f2ae14 authored by derekjchow's avatar derekjchow Committed by Sergio Guadarrama
Browse files

Download model in Jupyter Notebook. (#1580)

parent fb96b71a
......@@ -26,8 +26,11 @@
"source": [
"import numpy as np\n",
"import os\n",
"import six.moves.urllib as urllib\n",
"import sys\n",
"import tarfile\n",
"import tensorflow as tf\n",
"import zipfile\n",
"\n",
"from collections import defaultdict\n",
"from io import StringIO\n",
......@@ -89,7 +92,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Variables"
"## Variables\n",
"\n",
"See the [detection model zoo](g3doc/detection_model_zoo.md) for a list of all models to try."
]
},
{
......@@ -100,8 +105,13 @@
},
"outputs": [],
"source": [
"# What model to download.\n",
"MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017'\n",
"MODEL_FILE = MODEL_NAME + '.tar.gz'\n",
"DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'\n",
"\n",
"# Path to frozen detection graph. This is the actual model that is used for the object detection.\n",
"PATH_TO_CKPT = os.path.join('test_ckpt', 'ssd_inception_v2.pb')\n",
"PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'\n",
"\n",
"# List of the strings that is used to add correct label for each box.\n",
"PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')\n",
......@@ -113,13 +123,39 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load a (frozen) Tensorflow model into memory."
"## Download Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"opener = urllib.request.URLopener()\n",
"opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)\n",
"tar_file = tarfile.open(MODEL_FILE)\n",
"for file in tar_file.getmembers():\n",
" file_name = os.path.basename(file.name)\n",
" if 'frozen_inference_graph.pb' in file_name:\n",
" tar_file.extract(file, os.getcwd())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load a (frozen) Tensorflow model into memory."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"detection_graph = tf.Graph()\n",
......@@ -142,7 +178,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"label_map = label_map_util.load_labelmap(PATH_TO_LABELS)\n",
......@@ -201,6 +239,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
......@@ -237,25 +276,34 @@
" plt.figure(figsize=IMAGE_SIZE)\n",
" plt.imshow(image_np)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 2",
"language": "python",
"name": "python3"
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.4.3"
"pygments_lexer": "ipython2",
"version": "2.7.13"
}
},
"nbformat": 4,
......
......@@ -395,7 +395,7 @@ def visualize_boxes_and_labels_on_image_array(image,
classes[i] % len(STANDARD_COLORS)]
# Draw all boxes onto image.
for box, color in box_to_color_map.iteritems():
for box, color in six.iteritems(box_to_color_map):
ymin, xmin, ymax, xmax = box
if instance_masks is not None:
draw_mask_on_image_array(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment