"vscode:/vscode.git/clone" did not exist on "e7ef9a96f73d3fe127a979deeaadd2a19763a5ce"
Unverified Commit 4aac16f6 authored by kmindspark's avatar kmindspark Committed by GitHub
Browse files

Inference colab with updated checkpoints (#8831)

parent 9ec6f6e4
{ {
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "inference_tf2_colab.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"cells": [ "cells": [
{ {
"cell_type": "markdown", "cell_type": "markdown",
...@@ -24,26 +37,24 @@ ...@@ -24,26 +37,24 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {}, "id": "LBZ9VWZZFUCT",
"colab_type": "code", "colab_type": "code",
"id": "LBZ9VWZZFUCT" "colab": {}
}, },
"outputs": [],
"source": [ "source": [
"!pip install -U --pre tensorflow==\"2.2.0\"" "!pip install -U --pre tensorflow==\"2.2.0\""
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {}, "id": "oi28cqGGFWnY",
"colab_type": "code", "colab_type": "code",
"id": "oi28cqGGFWnY" "colab": {}
}, },
"outputs": [],
"source": [ "source": [
"import os\n", "import os\n",
"import pathlib\n", "import pathlib\n",
...@@ -54,17 +65,17 @@ ...@@ -54,17 +65,17 @@
" os.chdir('..')\n", " os.chdir('..')\n",
"elif not pathlib.Path('models').exists():\n", "elif not pathlib.Path('models').exists():\n",
" !git clone --depth 1 https://github.com/tensorflow/models" " !git clone --depth 1 https://github.com/tensorflow/models"
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {}, "id": "NwdsBdGhFanc",
"colab_type": "code", "colab_type": "code",
"id": "NwdsBdGhFanc" "colab": {}
}, },
"outputs": [],
"source": [ "source": [
"# Install the Object Detection API\n", "# Install the Object Detection API\n",
"%%bash\n", "%%bash\n",
...@@ -72,33 +83,17 @@ ...@@ -72,33 +83,17 @@
"protoc object_detection/protos/*.proto --python_out=.\n", "protoc object_detection/protos/*.proto --python_out=.\n",
"cp object_detection/packages/tf2/setup.py .\n", "cp object_detection/packages/tf2/setup.py .\n",
"python -m pip install ." "python -m pip install ."
] ],
},
{
"cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "outputs": []
"colab": {},
"colab_type": "code",
"id": "tNR8YgZVFhPm"
},
"outputs": [],
"source": [
"# Test the Object Detection API installation\n",
"%%bash\n",
"cd models/research\n",
"python object_detection/builders/model_builder_tf2_test.py"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code", "colab_type": "code",
"id": "yn5_uV1HLvaz" "id": "yn5_uV1HLvaz",
"colab": {}
}, },
"outputs": [],
"source": [ "source": [
"import matplotlib\n", "import matplotlib\n",
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
...@@ -117,7 +112,9 @@ ...@@ -117,7 +112,9 @@
"from object_detection.builders import model_builder\n", "from object_detection.builders import model_builder\n",
"\n", "\n",
"%matplotlib inline" "%matplotlib inline"
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
...@@ -131,13 +128,11 @@ ...@@ -131,13 +128,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code", "colab_type": "code",
"id": "-y9R0Xllefec" "id": "-y9R0Xllefec",
"colab": {}
}, },
"outputs": [],
"source": [ "source": [
"def load_image_into_numpy_array(path):\n", "def load_image_into_numpy_array(path):\n",
" \"\"\"Load an image from file into a numpy array.\n", " \"\"\"Load an image from file into a numpy array.\n",
...@@ -147,7 +142,7 @@ ...@@ -147,7 +142,7 @@
" (height, width, channels), where channels=3 for RGB.\n", " (height, width, channels), where channels=3 for RGB.\n",
"\n", "\n",
" Args:\n", " Args:\n",
" path: a file path.\n", " path: the file path to the image\n",
"\n", "\n",
" Returns:\n", " Returns:\n",
" uint8 numpy array with shape (img_height, img_width, 3)\n", " uint8 numpy array with shape (img_height, img_width, 3)\n",
...@@ -172,24 +167,26 @@ ...@@ -172,24 +167,26 @@
" for edge in kp_list:\n", " for edge in kp_list:\n",
" tuple_list.append((edge.start, edge.end))\n", " tuple_list.append((edge.start, edge.end))\n",
" return tuple_list" " return tuple_list"
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {}, "id": "R4YjnOjME1gy",
"colab_type": "code", "colab_type": "code",
"id": "R4YjnOjME1gy" "colab": {}
}, },
"outputs": [],
"source": [ "source": [
"# @title Choose the model to use, then evaluate the cell.\n", "# @title Choose the model to use, then evaluate the cell.\n",
"MODELS = {'centernet_with_keypoints': 'center_net_resnet101_v1_fpn_512x512_kpts_coco17_tpu-8', 'centernet_without_keypoints': 'center_net_resnet101_v1_fpn_512x512_coco17_tpu-8'}\n", "MODELS = {'centernet_with_keypoints': 'centernet_hg104_512x512_kpts_coco17_tpu-32', 'centernet_without_keypoints': 'centernet_hg104_512x512_coco17_tpu-8'}\n",
"\n", "\n",
"model_display_name = 'centernet_with_keypoints' # @param ['centernet_with_keypoints', 'centernet_without_keypoints']\n", "model_display_name = 'centernet_with_keypoints' # @param ['centernet_with_keypoints', 'centernet_without_keypoints']\n",
"model_name = MODELS[model_display_name]" "model_name = MODELS[model_display_name]"
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
...@@ -205,26 +202,33 @@ ...@@ -205,26 +202,33 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {}, "id": "ctPavqlyPuU_",
"colab_type": "code", "colab_type": "code",
"id": "ctPavqlyPuU_" "colab": {}
}, },
"outputs": [],
"source": [ "source": [
"# Download the checkpoint/ and put it into models/research/object_detection/test_data/" "# Download the checkpoint and put it into models/research/object_detection/test_data/\n",
] "\n",
"if model_display_name == 'centernet_with_keypoints':\n",
" !wget http://download.tensorflow.org/models/object_detection/tf2/20200710/centernet_hg104_512x512_kpts_coco17_tpu-32.tar.gz\n",
" !tar -xf centernet_hg104_512x512_kpts_coco17_tpu-32.tar.gz\n",
" !mv centernet_hg104_512x512_kpts_coco17_tpu-32/checkpoint models/research/object_detection/test_data/\n",
"else:\n",
" !wget http://download.tensorflow.org/models/object_detection/tf2/20200710/centernet_hg104_512x512_coco17_tpu-8.tar.gz\n",
" !tar -xf centernet_hg104_512x512_coco17_tpu-8.tar.gz\n",
" !mv centernet_hg104_512x512_coco17_tpu-8/checkpoint models/research/object_detection/test_data/"
],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code", "colab_type": "code",
"id": "4cni4SSocvP_" "id": "4cni4SSocvP_",
"colab": {}
}, },
"outputs": [],
"source": [ "source": [
"pipeline_config = os.path.join('models/research/object_detection/configs/tf2/',\n", "pipeline_config = os.path.join('models/research/object_detection/configs/tf2/',\n",
" model_name + '.config')\n", " model_name + '.config')\n",
...@@ -239,7 +243,7 @@ ...@@ -239,7 +243,7 @@
"# Restore checkpoint\n", "# Restore checkpoint\n",
"ckpt = tf.compat.v2.train.Checkpoint(\n", "ckpt = tf.compat.v2.train.Checkpoint(\n",
" model=detection_model)\n", " model=detection_model)\n",
"ckpt.restore(os.path.join(model_dir, 'ckpt-251')).expect_partial()\n", "ckpt.restore(os.path.join(model_dir, 'ckpt-0')).expect_partial()\n",
"\n", "\n",
"def get_model_detection_function(model):\n", "def get_model_detection_function(model):\n",
" \"\"\"Get a tf.function for detection.\"\"\"\n", " \"\"\"Get a tf.function for detection.\"\"\"\n",
...@@ -257,7 +261,9 @@ ...@@ -257,7 +261,9 @@
" return detect_fn\n", " return detect_fn\n",
"\n", "\n",
"detect_fn = get_model_detection_function(detection_model)" "detect_fn = get_model_detection_function(detection_model)"
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
...@@ -273,13 +279,11 @@ ...@@ -273,13 +279,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code", "colab_type": "code",
"id": "5mucYUS6exUJ" "id": "5mucYUS6exUJ",
"colab": {}
}, },
"outputs": [],
"source": [ "source": [
"label_map_path = configs['eval_input_config'].label_map_path\n", "label_map_path = configs['eval_input_config'].label_map_path\n",
"label_map = label_map_util.load_labelmap(label_map_path)\n", "label_map = label_map_util.load_labelmap(label_map_path)\n",
...@@ -289,7 +293,9 @@ ...@@ -289,7 +293,9 @@
" use_display_name=True)\n", " use_display_name=True)\n",
"category_index = label_map_util.create_category_index(categories)\n", "category_index = label_map_util.create_category_index(categories)\n",
"label_map_dict = label_map_util.get_label_map_dict(label_map, use_display_name=True)" "label_map_dict = label_map_util.get_label_map_dict(label_map, use_display_name=True)"
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
...@@ -315,13 +321,11 @@ ...@@ -315,13 +321,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code", "colab_type": "code",
"id": "vr_Fux-gfaG9" "id": "vr_Fux-gfaG9",
"colab": {}
}, },
"outputs": [],
"source": [ "source": [
"image_dir = 'models/research/object_detection/test_images/'\n", "image_dir = 'models/research/object_detection/test_images/'\n",
"image_path = os.path.join(image_dir, 'image2.jpg')\n", "image_path = os.path.join(image_dir, 'image2.jpg')\n",
...@@ -365,7 +369,9 @@ ...@@ -365,7 +369,9 @@
"plt.figure(figsize=(12,16))\n", "plt.figure(figsize=(12,16))\n",
"plt.imshow(image_np_with_detections)\n", "plt.imshow(image_np_with_detections)\n",
"plt.show()" "plt.show()"
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
...@@ -383,13 +389,11 @@ ...@@ -383,13 +389,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code", "colab_type": "code",
"id": "xBgYgSGMhHVi" "id": "xBgYgSGMhHVi",
"colab": {}
}, },
"outputs": [],
"source": [ "source": [
"if detection_model.__class__.__name__ != 'CenterNetMetaArch':\n", "if detection_model.__class__.__name__ != 'CenterNetMetaArch':\n",
" raise AssertionError('The meta-architecture for this section '\n", " raise AssertionError('The meta-architecture for this section '\n",
...@@ -458,20 +462,9 @@ ...@@ -458,20 +462,9 @@
"plt.imshow(resized_heatmap_unpadded, alpha=0.7,vmin=0, vmax=160, cmap='viridis')\n", "plt.imshow(resized_heatmap_unpadded, alpha=0.7,vmin=0, vmax=160, cmap='viridis')\n",
"plt.title('Object center heatmap (class: ' + class_name + ')')\n", "plt.title('Object center heatmap (class: ' + class_name + ')')\n",
"plt.show()" "plt.show()"
]
}
], ],
"metadata": { "execution_count": null,
"colab": { "outputs": []
"collapsed_sections": [],
"name": "inference_tf2_colab.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
} }
}, ]
"nbformat": 4,
"nbformat_minor": 0
} }
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment