"...text-generation-inference.git" did not exist on "0ad7f6f87d682496c4ff617608623db724935859"
Unverified commit 950ebc45, authored by kmindspark, committed by GitHub

Interactive Ducks Colab (#8821)

* add ducks

* add file

* add file

* add images

* cleaning up to test

* add colab

* latest

* add colab

* clean up pr

* change paths

* fix colab

* rename colab

* remove config

* fix more things in colab

* clear outputs from colab

* remove todos

* for testing purposes

* for testing purposes

* PR for interactive ducks

* add colab utils file

* add colab utils file

* add colab utils file

* add colab utils file

* add colab utils file

* add colab utils file

* add colab utils

* add separate utils file

* edit description

* final

* fix git repo url and remove installation test

* temp config

* add model checkpoint download

* remove config
parent f5184e5f
@@ -3,7 +3,7 @@
 "nbformat_minor": 0,
 "metadata": {
 "colab": {
-"name": "eager_few_shot_od_training_tf2_colab.ipynb",
+"name": "interactive_eager_few_shot_od_training_colab.ipynb",
 "provenance": [],
 "collapsed_sections": []
 },
@@ -26,25 +26,26 @@
 "Welcome to the Eager Few Shot Object Detection Colab --- in this colab we demonstrate fine tuning of a (TF2 friendly) RetinaNet architecture on very few examples of a novel class after initializing from a pre-trained COCO checkpoint.\n",
 "Training runs in eager mode.\n",
 "\n",
-"To run this colab faster, you can choose a GPU runtime via Runtime -> Change runtime type.\n",
+"To run this colab, you will need to connect to a borg runtime (we recommend\n",
+"Tensorflow GPU with Python 3).\n",
 "\n",
-"Estimated time to run through this colab (with GPU): < 5 minutes."
+"Estimated time to run through this colab: < 5 minutes."
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "YzEJA8Gapg4o",
-"colab_type": "text"
+"colab_type": "text",
+"id": "vPs64QA1Zdov"
 },
 "source": [
-"## Imports and Setup"
+"## Imports"
 ]
 },
 {
 "cell_type": "code",
 "metadata": {
-"id": "AFkb2D7RpgM1",
+"id": "LBZ9VWZZFUCT",
 "colab_type": "code",
 "colab": {}
 },
@@ -57,7 +58,7 @@
 {
 "cell_type": "code",
 "metadata": {
-"id": "3h-Rd2YpqV8m",
+"id": "oi28cqGGFWnY",
 "colab_type": "code",
 "colab": {}
 },
@@ -70,7 +71,7 @@
 " while \"models\" in pathlib.Path.cwd().parts:\n",
 " os.chdir('..')\n",
 "elif not pathlib.Path('models').exists():\n",
-" !git clone --depth 1 https://github.com/tensorflow/models\n"
+" !git clone --depth 1 https://github.com/tensorflow/models"
 ],
 "execution_count": null,
 "outputs": []
@@ -78,7 +79,7 @@
 {
 "cell_type": "code",
 "metadata": {
-"id": "JC7QeY4nqWBF",
+"id": "NwdsBdGhFanc",
 "colab_type": "code",
 "colab": {}
 },
@@ -93,27 +94,11 @@
 "execution_count": null,
 "outputs": []
 },
-{
-"cell_type": "code",
-"metadata": {
-"id": "HtYJEX-MoRbb",
-"colab_type": "code",
-"colab": {}
-},
-"source": [
-"# Test the Object Detection API installation\n",
-"%%bash\n",
-"cd models/research\n",
-"python object_detection/builders/model_builder_tf2_test.py"
-],
-"execution_count": null,
-"outputs": []
-},
 {
 "cell_type": "code",
 "metadata": {
 "colab_type": "code",
-"id": "yn5_uV1HLvaz",
+"id": "uZcqD4NLdnf4",
 "colab": {}
 },
 "source": [
@@ -121,18 +106,24 @@
 "import matplotlib.pyplot as plt\n",
 "\n",
 "import os\n",
-"import io\n",
 "import random\n",
+"import io\n",
+"import imageio\n",
+"import glob\n",
+"import scipy.misc\n",
 "import numpy as np\n",
 "from six import BytesIO\n",
 "from PIL import Image, ImageDraw, ImageFont\n",
+"from IPython.display import display, Javascript\n",
+"from IPython.display import Image as IPyImage\n",
 "\n",
 "import tensorflow as tf\n",
 "\n",
+"from object_detection.utils import label_map_util\n",
 "from object_detection.utils import config_util\n",
 "from object_detection.utils import visualization_utils as viz_utils\n",
+"from object_detection.utils import colab_utils\n",
 "from object_detection.builders import model_builder\n",
-"from object_detection.protos import model_pb2\n",
 "\n",
 "%matplotlib inline"
 ],
@@ -146,7 +137,7 @@
 "id": "IogyryF2lFBL"
 },
 "source": [
-"## Utilities"
+"# Utilities"
 ]
 },
 {
@@ -180,7 +171,7 @@
 " boxes,\n",
 " classes,\n",
 " scores,\n",
-" category_index,\n",
+" category_index, \n",
 " figsize=(12, 16),\n",
 " image_name=None):\n",
 " \"\"\"Wrapper function to visualize detections.\n",
@@ -195,7 +186,6 @@
 " boxes and plot all boxes as black with no classes or scores.\n",
 " category_index: a dict containing category dictionaries (each holding\n",
 " category index `id` and category name `name`) keyed by category indices.\n",
-" figsize: pair of ints indicating width, height (inches)\n",
 " \"\"\"\n",
 " image_np_with_annotations = image_np.copy()\n",
 " viz_utils.visualize_boxes_and_labels_on_image_array(\n",
@@ -206,11 +196,10 @@
 " category_index,\n",
 " use_normalized_coordinates=True,\n",
 " min_score_thresh=0.8)\n",
-" plt.figure(figsize=figsize)\n",
 " if (image_name):\n",
 " plt.imsave(image_name, image_np_with_annotations)\n",
 " else:\n",
-" plt.imshow(image_np_with_annotations)"
+" plt.imshow(image_np_with_annotations)\n"
 ],
 "execution_count": null,
 "outputs": []
@@ -224,45 +213,132 @@
 "source": [
 "# Rubber Ducky data\n",
 "\n",
-"Here is some toy (literally) data consisting of 5 annotated images of a rubber\n",
-"ducky. For simplicity, we explicitly write out the bounding box data in this cell. Note that the [coco](https://cocodataset.org/#explore) dataset contains a number of animals, but notably, it does *not* contain rubber duckies (or even ducks for that matter), so this is a novel class."
+"We will start with some toy (literally) data consisting of 5 images of a rubber\n",
+"ducky. Note that the [coco](https://cocodataset.org/#explore) dataset contains a number of animals, but notably, it does *not* contain rubber duckies (or even ducks for that matter), so this is a novel class."
 ]
 },
 {
 "cell_type": "code",
 "metadata": {
 "colab_type": "code",
-"id": "XePU382-vrou",
+"id": "SQy3ND7EpFQM",
 "colab": {}
 },
 "source": [
-"# Load images\n",
+"# Load images and visualize\n",
 "train_image_dir = 'models/research/object_detection/test_images/ducky/train/'\n",
-"train_images_np = {}\n",
+"train_images_np = []\n",
 "for i in range(1, 6):\n",
 " image_path = os.path.join(train_image_dir, 'robertducky' + str(i) + '.jpg')\n",
-" train_images_np[i-1] = np.expand_dims(\n",
-" load_image_into_numpy_array(image_path), axis=0)\n",
-"\n",
-"# Annotations (bounding boxes and classes) for each image\n",
-"# As is standard in the Object Detection API, boxes are listed in \n",
-"# [ymin, xmin, ymax, xmax] format using normalized coordinates (relative to\n",
-"# the width and height of the image).\n",
-"gt_boxes = {\n",
-" 0: np.array([[0.436, 0.591, 0.629, 0.712]], dtype=np.float32),\n",
-" 1: np.array([[0.539, 0.583, 0.73, 0.71]], dtype=np.float32),\n",
-" 2: np.array([[0.464, 0.414, 0.626, 0.548]], dtype=np.float32),\n",
-" 3: np.array([[0.313, 0.308, 0.648, 0.526]], dtype=np.float32),\n",
-" 4: np.array([[0.256, 0.444, 0.484, 0.629]], dtype=np.float32)\n",
-"}\n",
+" train_images_np.append(load_image_into_numpy_array(image_path))\n",
+"\n",
+"plt.rcParams['axes.grid'] = False\n",
+"plt.rcParams['xtick.labelsize'] = False\n",
+"plt.rcParams['ytick.labelsize'] = False\n",
+"plt.rcParams['xtick.top'] = False\n",
+"plt.rcParams['xtick.bottom'] = False\n",
+"plt.rcParams['ytick.left'] = False\n",
+"plt.rcParams['ytick.right'] = False\n",
+"plt.rcParams['figure.figsize'] = [14, 7]\n",
+"\n",
+"for idx, train_image_np in enumerate(train_images_np):\n",
+" plt.subplot(2,3, idx+1)\n",
+" plt.imshow(train_image_np)\n",
+"plt.show()"
+],
+"execution_count": null,
+"outputs": []
+},
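The comment block removed above documents the box convention that still applies throughout: the Object Detection API lists boxes as [ymin, xmin, ymax, xmax] in coordinates normalized by image height and width. A hedged sketch of converting a pixel-space box into that format (all values illustrative):

    import numpy as np

    def to_normalized_box(ymin, xmin, ymax, xmax, height, width):
      # Divide pixel coordinates by the image height/width to get the
      # normalized [ymin, xmin, ymax, xmax] layout the API expects.
      return np.array([[ymin / height, xmin / width,
                        ymax / height, xmax / width]], dtype=np.float32)

    # A 100x120-pixel box with top-left corner (200, 300) in a 480x640 image:
    to_normalized_box(200, 300, 300, 420, height=480, width=640)
    # -> [[0.4167, 0.4688, 0.625, 0.6562]] (rounded)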
+{
+"cell_type": "markdown",
+"metadata": {
+"colab_type": "text",
+"id": "cbKXmQoxcUgE"
+},
+"source": [
+"# Annotate images with bounding boxes\n",
+"\n",
+"In this cell you will annotate the rubber duckies --- draw a box around the rubber ducky in each image; click `next image` to go to the next image and `submit` when there are no more images.\n",
+"\n",
+"If you'd like to skip the manual annotation step, we totally understand. In this case, simply skip this cell and run the next cell instead, where we've prepopulated the groundtruth with pre-annotated bounding boxes.\n",
+"\n"
+]
+},
+{
+"cell_type": "code",
+"metadata": {
+"colab_type": "code",
+"id": "-nEDRoUEcUgL",
+"colab": {}
+},
+"source": [
+"gt_boxes = []\n",
+"colab_utils.annotate(train_images_np, box_storage_pointer=gt_boxes)"
+],
+"execution_count": null,
+"outputs": []
+},
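After the user hits `submit`, annotate populates gt_boxes in place: one N x 4 float array of normalized [ymin, xmin, ymax, xmax] rows per image, or None for an image with no boxes (see the annotate docstring in the new colab_utils module below). An illustrative sanity check one might run after this cell (not part of the PR):

    # Hypothetical check: one box array per training image, each with 4 columns.
    assert len(gt_boxes) == len(train_images_np)
    for boxes in gt_boxes:
      assert boxes is not None and boxes.shape[1] == 4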
+{
+"cell_type": "markdown",
+"metadata": {
+"colab_type": "text",
+"id": "wTP9AFqecUgS"
+},
+"source": [
+"# In case you didn't want to label...\n",
+"\n",
+"Run this cell only if you didn't annotate anything above and\n",
+"would prefer to just use our preannotated boxes. Don't forget\n",
+"to uncomment."
+]
+},
+{
+"cell_type": "code",
+"metadata": {
+"id": "wIAT6ZUmdHOC",
+"colab_type": "code",
+"colab": {}
+},
+"source": [
+"# gt_boxes = [\n",
+"# np.array([[0.436, 0.591, 0.629, 0.712]], dtype=np.float32),\n",
+"# np.array([[0.539, 0.583, 0.73, 0.71]], dtype=np.float32),\n",
+"# np.array([[0.464, 0.414, 0.626, 0.548]], dtype=np.float32),\n",
+"# np.array([[0.313, 0.308, 0.648, 0.526]], dtype=np.float32),\n",
+"# np.array([[0.256, 0.444, 0.484, 0.629]], dtype=np.float32)\n",
+"# ]"
+],
+"execution_count": null,
+"outputs": []
+},
+{
+"cell_type": "markdown",
+"metadata": {
+"colab_type": "text",
+"id": "Dqb_yjAo3cO_"
+},
+"source": [
+"# Prepare data for training\n",
+"\n",
+"Below we add the class annotations (for simplicity, we assume a single class in this colab; though it should be straightforward to extend this to handle multiple classes). We also convert everything to the format that the training\n",
+"loop below expects (e.g., everything converted to tensors, classes converted to one-hot representations, etc.)."
+]
+},
+{
+"cell_type": "code",
+"metadata": {
+"colab_type": "code",
+"id": "HWBqFVMcweF-",
+"colab": {}
+},
+"source": [
 "\n",
 "# By convention, our non-background classes start counting at 1. Given\n",
 "# that we will be predicting just one class, we will therefore assign it a\n",
 "# `class id` of 1.\n",
 "duck_class_id = 1\n",
 "num_classes = 1\n",
-"gt_classes = {\n",
-" i: np.array([duck_class_id], dtype=np.int32) for i in range(5)}\n",
+"\n",
 "category_index = {duck_class_id: {'id': duck_class_id, 'name': 'rubber_ducky'}}\n",
 "\n",
 "# Convert class labels to one-hot; convert everything to tensors.\n",
@@ -271,18 +347,19 @@
 "# classes start counting at the zeroth index. This is ordinarily just handled\n",
 "# automatically in our training binaries, but we need to reproduce it here.\n",
 "label_id_offset = 1\n",
-"train_image_tensors = {}\n",
-"gt_classes_one_hot_tensors = {}\n",
-"gt_box_tensors = {}\n",
-"for id in train_images_np:\n",
-" train_image_tensors[id] = tf.convert_to_tensor(\n",
-" train_images_np[id], dtype=tf.float32)\n",
-" gt_box_tensors[id] = tf.convert_to_tensor(gt_boxes[id])\n",
+"train_image_tensors = []\n",
+"gt_classes_one_hot_tensors = []\n",
+"gt_box_tensors = []\n",
+"for (train_image_np, gt_box_np) in zip(\n",
+" train_images_np, gt_boxes):\n",
+" train_image_tensors.append(tf.expand_dims(tf.convert_to_tensor(\n",
+" train_image_np, dtype=tf.float32), axis=0))\n",
+" gt_box_tensors.append(tf.convert_to_tensor(gt_box_np, dtype=tf.float32))\n",
 " zero_indexed_groundtruth_classes = tf.convert_to_tensor(\n",
-" gt_classes[id] - label_id_offset)\n",
-" gt_classes_one_hot_tensors[id] = tf.one_hot(\n",
-" zero_indexed_groundtruth_classes, num_classes)\n",
-"print('Done prepping data.')"
+" np.ones(shape=[gt_box_np.shape[0]], dtype=np.int32) - label_id_offset)\n",
+" gt_classes_one_hot_tensors.append(tf.one_hot(\n",
+" zero_indexed_groundtruth_classes, num_classes))\n",
+"print('Done prepping data.')\n"
 ],
 "execution_count": null,
 "outputs": []
@@ -306,11 +383,16 @@
 },
 "source": [
 "dummy_scores = np.array([1.0], dtype=np.float32) # give boxes a score of 100%\n",
-"for i in range(5):\n",
+"\n",
+"plt.figure(figsize=(30, 15))\n",
+"for idx in range(5):\n",
+" plt.subplot(2,3, idx+1)\n",
 " plot_detections(\n",
-" train_images_np[i][0],\n",
-" gt_boxes[i], gt_classes[i], dummy_scores, category_index)\n",
-" plt.show()"
+" train_images_np[idx],\n",
+" gt_boxes[idx],\n",
+" np.ones(shape=[gt_boxes[idx].shape[0]], dtype=np.int32),\n",
+" dummy_scores, category_index)\n",
+"plt.show()"
 ],
 "execution_count": null,
 "outputs": []
@@ -332,12 +414,16 @@
 {
 "cell_type": "code",
 "metadata": {
-"id": "Yq0tLasBwfsd",
+"id": "9J16r3NChD-7",
 "colab_type": "code",
 "colab": {}
 },
 "source": [
-"# Download the checkpoint/ and put it into models/research/object_detection/test_data/"
+"# Download the checkpoint and put it into models/research/object_detection/test_data/\n",
+"\n",
+"!wget http://download.tensorflow.org/models/object_detection/tf2/20200710/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz\n",
+"!tar -xf ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz\n",
+"!mv ssd_resnet50_v1_fpn_640x640_coco17_tpu-8/checkpoint models/research/object_detection/test_data/"
 ],
 "execution_count": null,
 "outputs": []
@@ -355,7 +441,7 @@
 "print('Building model and restoring weights for fine-tuning...', flush=True)\n",
 "num_classes = 1\n",
 "pipeline_config = 'models/research/object_detection/configs/tf2/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.config'\n",
-"checkpoint_path = 'models/research/object_detection/test_data/checkpoint/ckpt-26'\n",
+"checkpoint_path = 'models/research/object_detection/test_data/checkpoint/ckpt-0'\n",
 "\n",
 "# Load pipeline config and build a detection model.\n",
 "#\n",
@@ -469,7 +555,7 @@
 " for image_tensor in image_tensors], axis=0)\n",
 " prediction_dict = model.predict(preprocessed_images, shapes)\n",
 " losses_dict = model.loss(prediction_dict, shapes)\n",
-" total_loss = tf.add_n(losses_dict.values())\n",
+" total_loss = losses_dict['Loss/localization_loss'] + losses_dict['Loss/classification_loss']\n",
 " gradients = tape.gradient(total_loss, vars_to_fine_tune)\n",
 " optimizer.apply_gradients(zip(gradients, vars_to_fine_tune))\n",
 " return total_loss\n",
@@ -483,7 +569,7 @@
 "print('Start fine-tuning!', flush=True)\n",
 "for idx in range(num_batches):\n",
 " # Grab keys for a random subset of examples\n",
-" all_keys = sorted(train_images_np.keys())\n",
+" all_keys = list(range(len(train_images_np)))\n",
 " random.shuffle(all_keys)\n",
 " example_keys = all_keys[:batch_size]\n",
 "\n",
@@ -572,16 +658,11 @@
 {
 "cell_type": "code",
 "metadata": {
-"id": "PsmKjGxBPqed",
+"id": "RW1FrT2iNnpy",
 "colab_type": "code",
 "colab": {}
 },
 "source": [
-"import IPython\n",
-"from IPython import display\n",
-"import imageio\n",
-"import glob\n",
-"\n",
 "imageio.plugins.freeimage.download()\n",
 "\n",
 "anim_file = 'duckies_test.gif'\n",
@@ -596,7 +677,7 @@
 "\n",
 "imageio.mimsave(anim_file, images, 'GIF-FI', fps=5)\n",
 "\n",
-"display.display(display.Image(open(anim_file, 'rb').read()))\n"
+"display(IPyImage(open(anim_file, 'rb').read()))"
 ],
 "execution_count": null,
 "outputs": []
research/object_detection/utils/colab_utils.py (new file)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for colab tutorials located in object_detection/colab_tutorials/"""
import io
import json
import uuid
from typing import List, Dict, Union
from base64 import b64encode

from IPython.display import display, Javascript
from google.colab import output
from google.colab.output import eval_js
import numpy as np
from PIL import Image


def image_from_path(filepath):
  """Open the image at `filepath` and encode it in Base64.

  Note: assumed helper; the str branch of draw_bbox below calls it.
  """
  img = Image.open(filepath, mode='r')
  with io.BytesIO() as buffer:
    img.save(buffer, format="JPEG")
    return str(b64encode(buffer.getvalue()))[2:-1]


def image_from_numpy(image):
  """Encode an image given as a numpy array in Base64.

  Parameters
  ----------
  image: np.ndarray
    Image represented as a numpy array

  Returns
  -------
  str
    Encoded Base64 representation of the image
  """
  with io.BytesIO() as buffer:
    # Serialize the array as a JPEG and Base64-encode the bytes.
    Image.fromarray(image).save(buffer, format="JPEG")
    return str(b64encode(buffer.getvalue()))[2:-1]


def draw_bbox(image_urls, callbackId):
  """Open the bounding box UI and send the results to a callback function.

  Parameters
  ----------
  image_urls: list[str | np.ndarray]
    List of locations from which to load the images. If a np.ndarray is
    given, the array is interpreted as an image and sent to the frontend.
    If a str is given, the string is interpreted as a path and is read as
    a np.ndarray before being sent to the frontend.
  callbackId: str
    The ID for the callback function to send the bounding box results to
    when the user hits submit.
  """
  js = Javascript('''
async function load_image(imgs, callbackId) {
//init organizational elements
const div = document.createElement('div');
var image_cont = document.createElement('div');
var errorlog = document.createElement('div');
var crosshair_h = document.createElement('div');
crosshair_h.style.position = "absolute";
crosshair_h.style.backgroundColor = "transparent";
crosshair_h.style.width = "100%";
crosshair_h.style.height = "0px";
crosshair_h.style.zIndex = 9998;
crosshair_h.style.borderStyle = "dotted";
crosshair_h.style.borderWidth = "2px";
crosshair_h.style.borderColor = "rgba(255, 0, 0, 0.75)";
crosshair_h.style.cursor = "crosshair";
var crosshair_v = document.createElement('div');
crosshair_v.style.position = "absolute";
crosshair_v.style.backgroundColor = "transparent";
crosshair_v.style.width = "0px";
crosshair_v.style.height = "100%";
crosshair_v.style.zIndex = 9999;
crosshair_v.style.top = "0px";
crosshair_v.style.borderStyle = "dotted";
crosshair_v.style.borderWidth = "2px";
crosshair_v.style.borderColor = "rgba(255, 0, 0, 0.75)";
crosshair_v.style.cursor = "crosshair";
crosshair_v.style.marginTop = "23px";
var brdiv = document.createElement('br');
//init control elements
var next = document.createElement('button');
var prev = document.createElement('button');
var submit = document.createElement('button');
var deleteButton = document.createElement('button');
var deleteAllbutton = document.createElement('button');
//init image containers
var image = new Image();
var canvas_img = document.createElement('canvas');
var ctx = canvas_img.getContext("2d");
canvas_img.style.cursor = "crosshair";
canvas_img.setAttribute('draggable', false);
crosshair_v.setAttribute('draggable', false);
crosshair_h.setAttribute('draggable', false);
// bounding box containers
const height = 600
var allBoundingBoxes = [];
var curr_image = 0
var im_height = 0;
var im_width = 0;
//initialize bounding boxes
for (var i = 0; i < imgs.length; i++) {
allBoundingBoxes[i] = [];
}
//initialize image view
errorlog.id = 'errorlog';
image.style.display = 'block';
image.setAttribute('draggable', false);
//load the first image
img = imgs[curr_image];
image.src = "data:image/png;base64," + img;
image.onload = function() {
// normalize display height and canvas
image.height = height;
image_cont.height = canvas_img.height = image.height;
image_cont.width = canvas_img.width = image.naturalWidth;
crosshair_v.style.height = image_cont.height + "px";
crosshair_h.style.width = image_cont.width + "px";
// draw the new image
ctx.drawImage(image, 0, 0, image.naturalWidth, image.naturalHeight, 0, 0, canvas_img.width, canvas_img.height);
};
// move to next image in array
next.textContent = "next image";
next.onclick = function(){
if (curr_image < imgs.length - 1){
// clear canvas and load new image
curr_image += 1;
errorlog.innerHTML = "";
}
else{
errorlog.innerHTML = "All images completed!!";
}
resetcanvas();
}
//move backward through the list of images
prev.textContent = "prev image";
prev.onclick = function(){
if (curr_image > 0){
// clear canvas and load new image
curr_image -= 1;
errorlog.innerHTML = "";
}
else{
errorlog.innerHTML = "at the beginning";
}
resetcanvas();
}
// on delete, deletes the last bounding box
deleteButton.textContent = "undo bbox";
deleteButton.onclick = function(){
boundingBoxes.pop();
ctx.clearRect(0, 0, canvas_img.width, canvas_img.height);
image.src = "data:image/png;base64," + img;
image.onload = function() {
ctx.drawImage(image, 0, 0, image.naturalWidth, image.naturalHeight, 0, 0, canvas_img.width, canvas_img.height);
boundingBoxes.map(r => {drawRect(r)});
};
}
// on delete all, deletes all of the bounding boxes
deleteAllbutton.textContent = "delete all"
deleteAllbutton.onclick = function(){
boundingBoxes = [];
ctx.clearRect(0, 0, canvas_img.width, canvas_img.height);
image.src = "data:image/png;base64," + img;
image.onload = function() {
ctx.drawImage(image, 0, 0, image.naturalWidth, image.naturalHeight, 0, 0, canvas_img.width, canvas_img.height);
//boundingBoxes.map(r => {drawRect(r)});
};
}
// on submit, send the boxes to display
submit.textContent = "submit";
submit.onclick = function(){
errorlog.innerHTML = "";
// send box data to the callback function
google.colab.kernel.invokeFunction(callbackId, [allBoundingBoxes], {});
}
// init template for annotations
const annotation = {
x: 0,
y: 0,
w: 0,
h: 0,
};
// the array of all rectangles
let boundingBoxes = allBoundingBoxes[curr_image];
// the actual rectangle, the one that is being drawn
let o = {};
// a variable to store the mouse position
let m = {},
// a variable to store the point where you begin to draw the
// rectangle
start = {};
// a boolean variable to store the drawing state
let isDrawing = false;
var elem = null;
function handleMouseDown(e) {
// on mouse click set change the cursor and start tracking the mouse position
start = oMousePos(canvas_img, e);
// configure is drawing to true
isDrawing = true;
}
function handleMouseMove(e) {
// move crosshairs, but only within the bounds of the canvas
if (document.elementsFromPoint(e.pageX, e.pageY).includes(canvas_img)) {
crosshair_h.style.top = e.pageY + "px";
crosshair_v.style.left = e.pageX + "px";
}
// move the bounding box
if(isDrawing){
m = oMousePos(canvas_img, e);
draw();
}
}
function handleMouseUp(e) {
if (isDrawing) {
// on mouse release, push a bounding box to array and draw all boxes
isDrawing = false;
const box = Object.create(annotation);
// calculate the position of the rectangle
if (o.w > 0){
box.x = o.x;
}
else{
box.x = o.x + o.w;
}
if (o.h > 0){
box.y = o.y;
}
else{
box.y = o.y + o.h;
}
box.w = Math.abs(o.w);
box.h = Math.abs(o.h);
// add the bounding box to the image
boundingBoxes.push(box);
draw();
}
}
function draw() {
o.x = (start.x)/image.width; // start position of x
o.y = (start.y)/image.height; // start position of y
o.w = (m.x - start.x)/image.width; // width
o.h = (m.y - start.y)/image.height; // height
ctx.clearRect(0, 0, canvas_img.width, canvas_img.height);
ctx.drawImage(image, 0, 0, image.naturalWidth, image.naturalHeight, 0, 0, canvas_img.width, canvas_img.height);
// draw all the rectangles saved in the rectsRy
boundingBoxes.map(r => {drawRect(r)});
// draw the actual rectangle
drawRect(o);
}
// add the handlers needed for dragging
crosshair_h.addEventListener("mousedown", handleMouseDown);
crosshair_v.addEventListener("mousedown", handleMouseDown);
document.addEventListener("mousemove", handleMouseMove);
document.addEventListener("mouseup", handleMouseUp);
function resetcanvas(){
// clear canvas
ctx.clearRect(0, 0, canvas_img.width, canvas_img.height);
img = imgs[curr_image]
image.src = "data:image/png;base64," + img;
// onload init new canvas and display image
image.onload = function() {
// normalize display height and canvas
image.height = height;
image_cont.height = canvas_img.height = image.height;
image_cont.width = canvas_img.width = image.naturalWidth;
crosshair_v.style.height = image_cont.height + "px";
crosshair_h.style.width = image_cont.width + "px";
// draw the new image
ctx.drawImage(image, 0, 0, image.naturalWidth, image.naturalHeight, 0, 0, canvas_img.width, canvas_img.height);
// draw bounding boxes
boundingBoxes = allBoundingBoxes[curr_image];
boundingBoxes.map(r => {drawRect(r)});
};
}
function drawRect(o){
// draw a predefined rectangle
ctx.strokeStyle = "red";
ctx.lineWidth = 2;
ctx.beginPath(o);
ctx.rect(o.x * image.width, o.y * image.height, o.w * image.width, o.h * image.height);
ctx.stroke();
}
// Function to detect the mouse position
function oMousePos(canvas_img, evt) {
let ClientRect = canvas_img.getBoundingClientRect();
return {
x: evt.clientX - ClientRect.left,
y: evt.clientY - ClientRect.top
};
}
//configure colab output display
google.colab.output.setIframeHeight(document.documentElement.scrollHeight, true);
//build the html document that will be seen in output
div.appendChild(document.createElement('br'))
div.appendChild(image_cont)
image_cont.appendChild(canvas_img)
image_cont.appendChild(crosshair_h)
image_cont.appendChild(crosshair_v)
div.appendChild(document.createElement('br'))
div.appendChild(errorlog)
div.appendChild(prev)
div.appendChild(next)
div.appendChild(deleteButton)
div.appendChild(deleteAllbutton)
div.appendChild(document.createElement('br'))
div.appendChild(brdiv)
div.appendChild(submit)
document.querySelector("#output-area").appendChild(div);
return
}''')
  # Load the images as Base64-encoded byte strings.
  bytearrays = []
  for image in image_urls:
    if isinstance(image, str):
      bytearrays.append(image_from_path(image))
    elif isinstance(image, np.ndarray):
      bytearrays.append(image_from_numpy(image))
    else:
      raise TypeError(f"Image has unsupported type {type(image)}. "
                      "Only str and np.ndarray are supported.")

  # Format the arrays for input.
  image_data = json.dumps(bytearrays)
  del bytearrays

  # Call the JavaScript function, passing the encoded images as input.
  display(js)
  eval_js(f"load_image({image_data}, '{callbackId}')")
  return


def annotate(imgs: List[Union[str, np.ndarray]],
             box_storage_pointer: List[np.ndarray],
             callbackId: str = None):
  """Open the bounding box UI and prompt the user for input.

  Parameters
  ----------
  imgs: list[str | np.ndarray]
    List of locations from which to load the images. If a np.ndarray is
    given, the array is interpreted as an image and sent to the frontend.
    If a str is given, the string is interpreted as a path and is read as
    a np.ndarray before being sent to the frontend.
  box_storage_pointer: list[np.ndarray]
    Destination list for the bounding box arrays. Each array in this list
    corresponds to one of the images given in imgs. The array is an
    N x 4 array where N is the number of bounding boxes given by the user
    for that particular image. If there are no bounding boxes for an image,
    None is used instead of an empty array.
  callbackId: str, optional
    The ID for the callback function that communicates between the frontend
    and the backend. If no ID is given, a random UUID string is used instead.
  """
  # Set a random ID for the callback function if none was provided.
  if callbackId is None:
    callbackId = str(uuid.uuid1()).replace('-', '')

  def dictToList(input_bbox):
    """Convert a frontend dict in the format {x, y, w, h} (as produced by
    callbackFunction below) into a [y_min, x_min, y_max, x_max] list."""
    return (input_bbox['y'], input_bbox['x'],
            input_bbox['y'] + input_bbox['h'],
            input_bbox['x'] + input_bbox['w'])
  def callbackFunction(annotations: List[List[Dict[str, float]]]):
    """Capture the annotation data from the frontend and convert it into
    numpy arrays.

    Parameters
    ----------
    annotations: list[list[dict[str, float]]]
      The input of the callback function is a list of lists of dicts
      corresponding to the annotations. The format of annotations is shown
      below:

      [
        // stuff for image 1
        [
          // stuff for rect 1
          {x, y, w, h},
          // stuff for rect 2
          {x, y, w, h},
          ...
        ],
        // stuff for image 2
        [
          // stuff for rect 1
          {x, y, w, h},
          // stuff for rect 2
          {x, y, w, h},
          ...
        ],
        ...
      ]
    """
    # Reset the boxes list.
    nonlocal box_storage_pointer
    boxes: List[np.ndarray] = box_storage_pointer
    boxes.clear()

    # Load the new annotations into the boxes list.
    for annotationsPerImg in annotations:
      rectanglesAsArrays = [np.clip(dictToList(annotation), 0, 1)
                            for annotation in annotationsPerImg]
      if rectanglesAsArrays:
        boxes.append(np.stack(rectanglesAsArrays))
      else:
        boxes.append(None)

    # Output the annotations to the errorlog.
    with output.redirect_to_element('#errorlog'):
      display("--boxes array populated--")

  output.register_callback(callbackId, callbackFunction)
  draw_bbox(imgs, callbackId)
\ No newline at end of file
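For reference, a hedged end-to-end sketch of driving this module from a Colab cell (the file names are hypothetical; this mirrors the notebook cell added above):

    import numpy as np
    from PIL import Image
    from object_detection.utils import colab_utils

    # Hypothetical local images; plain path strings would work as well.
    images = [np.asarray(Image.open(p)) for p in ['duck1.jpg', 'duck2.jpg']]
    boxes = []
    colab_utils.annotate(images, box_storage_pointer=boxes)
    # After clicking `submit` in the UI, boxes[i] holds an N x 4 array of
    # normalized [ymin, xmin, ymax, xmax] rows for image i (or None).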