"Welcome to the Eager Few Shot Object Detection Colab --- in this colab we demonstrate fine tuning of a (TF2 friendly) RetinaNet architecture on very few examples of a novel class after initializing from a pre-trained COCO checkpoint.\n",
"Training runs in eager mode.\n",
"\n",
"Estimated time to run through this colab (with GPU): \u003c 5 minutes."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "vPs64QA1Zdov"
},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "LBZ9VWZZFUCT"
},
"outputs": [],
"source": [
"!pip install -U --pre tensorflow==\"2.2.0\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "oi28cqGGFWnY"
},
"outputs": [],
"source": [
"import os\n",
"import pathlib\n",
"\n",
"# Clone the tensorflow models repository if it doesn't already exist\n",
"if \"models\" in pathlib.Path.cwd().parts:\n",
" while \"models\" in pathlib.Path.cwd().parts:\n",
"We will start with some toy (literally) data consisting of 5 images of a rubber\n",
"ducky. Note that the [coco](https://cocodataset.org/#explore) dataset contains a number of animals, but notably, it does *not* contain rubber duckies (or even ducks for that matter), so this is a novel class."
"for idx, train_image_np in enumerate(train_images_np):\n",
" plt.subplot(2, 3, idx+1)\n",
" plt.imshow(train_image_np)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "cbKXmQoxcUgE"
},
"source": [
"# Annotate images with bounding boxes\n",
"\n",
"In this cell you will annotate the rubber duckies --- draw a box around the rubber ducky in each image; click `next image` to go to the next image and `submit` when there are no more images.\n",
"\n",
"If you'd like to skip the manual annotation step, we totally understand. In this case, simply skip this cell and run the next cell instead, where we've prepopulated the groundtruth with pre-annotated bounding boxes.\n",
"Below we add the class annotations (for simplicity, we assume a single class in this colab; though it should be straightforward to extend this to handle multiple classes). We also convert everything to the format that the training\n",
"loop below expects (e.g., everything converted to tensors, classes converted to one-hot representations, etc.)."
]
},
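{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The next cell is a minimal sketch of the conversion described above. It assumes `train_images_np` holds the training images and that `gt_boxes` is a list of `[N, 4]` numpy arrays of normalized `[ymin, xmin, ymax, xmax]` coordinates, one per image (from either the annotation tool or the pre-annotated boxes); `num_classes = 1` matches the class-id convention set up in the next cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"num_classes = 1      # single rubber ducky class\n",
"label_id_offset = 1  # class ids start at 1; one-hot vectors are zero-indexed\n",
"\n",
"train_image_tensors = []\n",
"gt_box_tensors = []\n",
"gt_classes_one_hot_tensors = []\n",
"for (train_image_np, gt_box_np) in zip(train_images_np, gt_boxes):\n",
"  # Add a batch dimension and convert the image to a float tensor.\n",
"  train_image_tensors.append(\n",
"      tf.expand_dims(tf.convert_to_tensor(train_image_np, dtype=tf.float32), axis=0))\n",
"  gt_box_tensors.append(tf.convert_to_tensor(gt_box_np, dtype=tf.float32))\n",
"  # Every box in this toy dataset is the (single) rubber ducky class.\n",
"  zero_indexed_classes = tf.convert_to_tensor(\n",
"      np.ones(shape=[gt_box_np.shape[0]], dtype=np.int32) - label_id_offset)\n",
"  gt_classes_one_hot_tensors.append(tf.one_hot(zero_indexed_classes, num_classes))\n",
"\n",
"print('Done prepping data.')"
]
},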
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "HWBqFVMcweF-"
},
"outputs": [],
"source": [
"\n",
"# By convention, our non-background classes start counting at 1. Given\n",
"# that we will be predicting just one class, we will therefore assign it a\n",
"# Create model and restore weights for all but last layer\n",
"\n",
"In this cell we build a single stage detection architecture (RetinaNet) and restore all but the classification layer at the top (which will be automatically randomly initialized).\n",
"\n",
"For simplicity, we have hardcoded a number of things in this colab for the specific RetinaNet architecture at hand (including assuming that the image size will always be 640x640), however it is not difficult to generalize to other model configurations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "9J16r3NChD-7"
},
"outputs": [],
"source": [
"# Download the checkpoint and put it into models/research/object_detection/test_data/\n",
"Welcome to the object detection colab! This demo will take you through the steps of running an \"out-of-the-box\" detection model on a collection of images."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "vPs64QA1Zdov"
},
"source": [
"## Imports and Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "LBZ9VWZZFUCT"
},
"outputs": [],
"source": [
"!pip install -U --pre tensorflow==\"2.2.0\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "oi28cqGGFWnY"
},
"outputs": [],
"source": [
"import os\n",
"import pathlib\n",
"\n",
"# Clone the tensorflow models repository if it doesn't already exist\n",
"if \"models\" in pathlib.Path.cwd().parts:\n",
" while \"models\" in pathlib.Path.cwd().parts:\n",
"Label maps correspond index numbers to category names, so that when our convolution network predicts `5`, we know that this corresponds to `airplane`. Here we use internal utility functions, but anything that returns a dictionary mapping integers to appropriate string labels would be fine."
"Run the below code which loads an image, runs it through the detection model and visualizes the detection results, including the keypoints.\n",
"\n",
"Note that this will take a long time (several minutes) the first time you run this code due to tf.function's trace-compilation --- on subsequent runs (e.g. on new images), things will be faster.\n",
"\n",
"Here are some simple things to try out if you are curious:\n",
"* Try running inference on your own images (local paths work)\n",
"* Modify some of the input images and see if detection still works. Some simple things to try out here (just uncomment the relevant portions of code) include flipping the image horizontally, or converting to grayscale (note that we still expect the input image to have 3 channels).\n",
"* Print out `detections['detection_boxes']` and try to match the box locations to the boxes in the image. Notice that coordinates are given in normalized form (i.e., in the interval [0, 1]).\n",
"* Set min_score_thresh to other values (between 0 and 1) to allow more detections in or to filter out more detections.\n",
"\n",
"Note that you can run this cell repeatedly without rerunning earlier cells.\n"
"## Digging into the model's intermediate predictions\n",
"\n",
"For this part we will assume that the detection model is a CenterNet model following Zhou et al (https://arxiv.org/abs/1904.07850). And more specifically, we will assume that `detection_model` is of type `meta_architectures.center_net_meta_arch.CenterNetMetaArch`.\n",
"\n",
"As one of its intermediate predictions, CenterNet produces a heatmap of box centers for each class (for example, it will produce a heatmap whose size is proportional to that of the image that lights up at the center of each, e.g., \"zebra\"). In the following, we will visualize these intermediate class center heatmap predictions."