Unverified Commit 4aac16f6 authored by kmindspark's avatar kmindspark Committed by GitHub
Browse files

Inference colab with updated checkpoints (#8831)

parent 9ec6f6e4
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "inference_tf2_colab.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"cells": [
{
"cell_type": "markdown",
......@@ -24,26 +37,24 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"id": "LBZ9VWZZFUCT",
"colab_type": "code",
"id": "LBZ9VWZZFUCT"
"colab": {}
},
"outputs": [],
"source": [
"!pip install -U --pre tensorflow==\"2.2.0\""
]
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"id": "oi28cqGGFWnY",
"colab_type": "code",
"id": "oi28cqGGFWnY"
"colab": {}
},
"outputs": [],
"source": [
"import os\n",
"import pathlib\n",
......@@ -54,17 +65,17 @@
" os.chdir('..')\n",
"elif not pathlib.Path('models').exists():\n",
" !git clone --depth 1 https://github.com/tensorflow/models"
]
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"id": "NwdsBdGhFanc",
"colab_type": "code",
"id": "NwdsBdGhFanc"
"colab": {}
},
"outputs": [],
"source": [
"# Install the Object Detection API\n",
"%%bash\n",
......@@ -72,33 +83,17 @@
"protoc object_detection/protos/*.proto --python_out=.\n",
"cp object_detection/packages/tf2/setup.py .\n",
"python -m pip install ."
]
},
{
"cell_type": "code",
],
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "tNR8YgZVFhPm"
},
"outputs": [],
"source": [
"# Test the Object Detection API installation\n",
"%%bash\n",
"cd models/research\n",
"python object_detection/builders/model_builder_tf2_test.py"
]
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "yn5_uV1HLvaz"
"id": "yn5_uV1HLvaz",
"colab": {}
},
"outputs": [],
"source": [
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
......@@ -117,7 +112,9 @@
"from object_detection.builders import model_builder\n",
"\n",
"%matplotlib inline"
]
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
......@@ -131,13 +128,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "-y9R0Xllefec"
"id": "-y9R0Xllefec",
"colab": {}
},
"outputs": [],
"source": [
"def load_image_into_numpy_array(path):\n",
" \"\"\"Load an image from file into a numpy array.\n",
......@@ -147,7 +142,7 @@
" (height, width, channels), where channels=3 for RGB.\n",
"\n",
" Args:\n",
" path: a file path.\n",
" path: the file path to the image\n",
"\n",
" Returns:\n",
" uint8 numpy array with shape (img_height, img_width, 3)\n",
......@@ -172,24 +167,26 @@
" for edge in kp_list:\n",
" tuple_list.append((edge.start, edge.end))\n",
" return tuple_list"
]
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"id": "R4YjnOjME1gy",
"colab_type": "code",
"id": "R4YjnOjME1gy"
"colab": {}
},
"outputs": [],
"source": [
"# @title Choose the model to use, then evaluate the cell.\n",
"MODELS = {'centernet_with_keypoints': 'center_net_resnet101_v1_fpn_512x512_kpts_coco17_tpu-8', 'centernet_without_keypoints': 'center_net_resnet101_v1_fpn_512x512_coco17_tpu-8'}\n",
"MODELS = {'centernet_with_keypoints': 'centernet_hg104_512x512_kpts_coco17_tpu-32', 'centernet_without_keypoints': 'centernet_hg104_512x512_coco17_tpu-8'}\n",
"\n",
"model_display_name = 'centernet_with_keypoints' # @param ['centernet_with_keypoints', 'centernet_without_keypoints']\n",
"model_name = MODELS[model_display_name]"
]
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
......@@ -205,26 +202,33 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"id": "ctPavqlyPuU_",
"colab_type": "code",
"id": "ctPavqlyPuU_"
"colab": {}
},
"outputs": [],
"source": [
"# Download the checkpoint/ and put it into models/research/object_detection/test_data/"
]
"# Download the checkpoint and put it into models/research/object_detection/test_data/\n",
"\n",
"if model_display_name == 'centernet_with_keypoints':\n",
" !wget http://download.tensorflow.org/models/object_detection/tf2/20200710/centernet_hg104_512x512_kpts_coco17_tpu-32.tar.gz\n",
" !tar -xf centernet_hg104_512x512_kpts_coco17_tpu-32.tar.gz\n",
" !mv centernet_hg104_512x512_kpts_coco17_tpu-32/checkpoint models/research/object_detection/test_data/\n",
"else:\n",
" !wget http://download.tensorflow.org/models/object_detection/tf2/20200710/centernet_hg104_512x512_coco17_tpu-8.tar.gz\n",
" !tar -xf centernet_hg104_512x512_coco17_tpu-8.tar.gz\n",
" !mv centernet_hg104_512x512_coco17_tpu-8/checkpoint models/research/object_detection/test_data/"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "4cni4SSocvP_"
"id": "4cni4SSocvP_",
"colab": {}
},
"outputs": [],
"source": [
"pipeline_config = os.path.join('models/research/object_detection/configs/tf2/',\n",
" model_name + '.config')\n",
......@@ -239,7 +243,7 @@
"# Restore checkpoint\n",
"ckpt = tf.compat.v2.train.Checkpoint(\n",
" model=detection_model)\n",
"ckpt.restore(os.path.join(model_dir, 'ckpt-251')).expect_partial()\n",
"ckpt.restore(os.path.join(model_dir, 'ckpt-0')).expect_partial()\n",
"\n",
"def get_model_detection_function(model):\n",
" \"\"\"Get a tf.function for detection.\"\"\"\n",
......@@ -257,7 +261,9 @@
" return detect_fn\n",
"\n",
"detect_fn = get_model_detection_function(detection_model)"
]
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
......@@ -273,13 +279,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "5mucYUS6exUJ"
"id": "5mucYUS6exUJ",
"colab": {}
},
"outputs": [],
"source": [
"label_map_path = configs['eval_input_config'].label_map_path\n",
"label_map = label_map_util.load_labelmap(label_map_path)\n",
......@@ -289,7 +293,9 @@
" use_display_name=True)\n",
"category_index = label_map_util.create_category_index(categories)\n",
"label_map_dict = label_map_util.get_label_map_dict(label_map, use_display_name=True)"
]
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
......@@ -315,13 +321,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "vr_Fux-gfaG9"
"id": "vr_Fux-gfaG9",
"colab": {}
},
"outputs": [],
"source": [
"image_dir = 'models/research/object_detection/test_images/'\n",
"image_path = os.path.join(image_dir, 'image2.jpg')\n",
......@@ -365,7 +369,9 @@
"plt.figure(figsize=(12,16))\n",
"plt.imshow(image_np_with_detections)\n",
"plt.show()"
]
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
......@@ -383,13 +389,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "xBgYgSGMhHVi"
"id": "xBgYgSGMhHVi",
"colab": {}
},
"outputs": [],
"source": [
"if detection_model.__class__.__name__ != 'CenterNetMetaArch':\n",
" raise AssertionError('The meta-architecture for this section '\n",
......@@ -458,20 +462,9 @@
"plt.imshow(resized_heatmap_unpadded, alpha=0.7,vmin=0, vmax=160, cmap='viridis')\n",
"plt.title('Object center heatmap (class: ' + class_name + ')')\n",
"plt.show()"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "inference_tf2_colab.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
"execution_count": null,
"outputs": []
}
},
"nbformat": 4,
"nbformat_minor": 0
]
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment