"Welcome to the **Few Shot Object Detection for TensorFlow Lite** Colab. Here, we demonstrate fine-tuning of an SSD architecture (pre-trained on COCO) on very few examples of a *novel* class. We will then generate a (downloadable) TensorFlow Lite model for on-device inference.\n",
"\n",
"**NOTE:** This Colab is meant for the few-shot detection use-case. To train a model on a large dataset, please follow the [TF2 training](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_training_and_evaluation.md#training) documentation and then [convert](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_on_mobile_tf2.md) the model to TensorFlow Lite."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3U2sv0upw04O"
},
"source": [
"# Set Up"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vPs64QA1Zdov"
},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "H0rKBV4uZacD"
},
"outputs": [],
"source": [
"# Support for TF2 models was added after TF 2.3.\n",
"!pip install tf-nightly"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oi28cqGGFWnY"
},
"outputs": [],
"source": [
"import os\n",
"import pathlib\n",
"\n",
"# Clone the tensorflow models repository if it doesn't already exist\n",
"if \"models\" in pathlib.Path.cwd().parts:\n",
"  while \"models\" in pathlib.Path.cwd().parts:\n",
"    os.chdir('..')\n",
"elif not pathlib.Path('models').exists():\n",
"  !git clone --depth 1 https://github.com/tensorflow/models"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will start with some toy data consisting of 5 images of a rubber\n",
"ducky. Note that the [COCO](https://cocodataset.org/#explore) dataset contains a number of animals, but notably, it does *not* contain rubber duckies (or even ducks, for that matter), so this is a novel class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualize the training images (assumes matplotlib.pyplot has been\n",
"# imported as plt and the 5 images loaded into train_images_np).\n",
"for idx, train_image_np in enumerate(train_images_np):\n",
" plt.subplot(2, 3, idx+1)\n",
" plt.imshow(train_image_np)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LbOe9Ym7xMGV"
},
"source": [
"# Transfer Learning\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Dqb_yjAo3cO_"
},
"source": [
"## Data Preparation\n",
"\n",
"First, we populate the groundtruth with pre-annotated bounding boxes.\n",
"\n",
"We then add the class annotations (for simplicity, we assume a single 'Duck' class in this colab, though it should be straightforward to extend this to handle multiple classes). We also convert everything to the format that the training\n",
"loop below expects (e.g., everything converted to tensors, classes converted to one-hot representations, etc.)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create model and restore weights for all but last layer\n",
"\n",
"In this cell we build a mobile-friendly single-stage detection architecture (SSD MobileNet V2 FPN-Lite) and restore all but the classification layer at the top (which will be randomly initialized).\n",
"\n",
"**NOTE**: TensorFlow Lite only supports SSD models for now.\n",
"\n",
"For simplicity, we have hardcoded a number of things in this colab for the specific SSD architecture at hand (including assuming that the image size will always be 320x320). However, it is not difficult to generalize to other model configurations, via the `pipeline.config` in the zip downloaded from the [Model Zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md).\n",
"\n",
"\n"
]
},
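{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model-building step described above can be sketched as follows. This is an illustrative sketch, not the exact cell used in this colab; the `pipeline_config` path is an assumption:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch: build the SSD model from its pipeline config and\n",
"# override the number of classes for our single 'Duck' class.\n",
"# (Assumes the Object Detection API is installed; the config path below\n",
"# is an assumption.)\n",
"from object_detection.utils import config_util\n",
"from object_detection.builders import model_builder\n",
"\n",
"pipeline_config = 'models/research/object_detection/configs/tf2/ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8.config'\n",
"configs = config_util.get_configs_from_pipeline_file(pipeline_config)\n",
"model_config = configs['model']\n",
"model_config.ssd.num_classes = 1\n",
"model_config.ssd.freeze_batchnorm = True\n",
"detection_model = model_builder.build(\n",
"    model_config=model_config, is_training=True)"
]
},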
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9J16r3NChD-7"
},
"outputs": [],
"source": [
"# Download the checkpoint and put it into models/research/object_detection/test_data/\n",
"!wget http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8.tar.gz\n",
"!tar -xf ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8.tar.gz\n",
"!mv ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8/checkpoint models/research/object_detection/test_data/"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Eager mode custom training loop\n",
"\n",
"Some of the parameters in this block have been set empirically: for example, `learning_rate`, `num_batches` & `momentum` for SGD. These are just a starting point; you will need to tune them for your data & model architecture to get the best results."
]
},
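{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sketch of what one fine-tuning step looks like (illustrative only; assumes `detection_model`, the list of variables `to_fine_tune`, and batched groundtruth tensors already exist):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch of a single eager-mode training step.\n",
"# (Names other than the TF / Object Detection API calls are assumptions.)\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.15, momentum=0.9)\n",
"\n",
"def train_step(image_tensors, gt_boxes_list, gt_classes_list):\n",
"  shapes = tf.constant(len(image_tensors) * [[320, 320, 3]], dtype=tf.int32)\n",
"  # Provide groundtruth so detection_model.loss() can compute the loss.\n",
"  detection_model.provide_groundtruth(\n",
"      groundtruth_boxes_list=gt_boxes_list,\n",
"      groundtruth_classes_list=gt_classes_list)\n",
"  with tf.GradientTape() as tape:\n",
"    preprocessed = tf.concat(\n",
"        [detection_model.preprocess(t)[0] for t in image_tensors], axis=0)\n",
"    prediction_dict = detection_model.predict(preprocessed, shapes)\n",
"    losses = detection_model.loss(prediction_dict, shapes)\n",
"    total_loss = (losses['Loss/localization_loss'] +\n",
"                  losses['Loss/classification_loss'])\n",
"  # Only update the fine-tuned (top-layer) variables.\n",
"  gradients = tape.gradient(total_loss, to_fine_tune)\n",
"  optimizer.apply_gradients(zip(gradients, to_fine_tune))\n",
"  return total_loss"
]
},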
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nyHoF4mUrv5-"
},
"outputs": [],
"source": [
"tf.keras.backend.set_learning_phase(True)\n",
"\n",
"# These parameters can be tuned; since our training set has 5 images\n",
"# it doesn't make sense to have a much larger batch size, though we could\n",
"# fit more examples in memory if we wanted to.\n",
"batch_size = 5\n",
"learning_rate = 0.15\n",
"num_batches = 1000\n",
"\n",
"# Select variables in top layers to fine-tune.\n",
"trainable_variables = detection_model.trainable_variables\n",
"to_fine_tune = []\n",
"prefixes_to_train = [\n",
"  'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalBoxHead',\n",
"  'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalClassHead']\n",
"for var in trainable_variables:\n",
"  if any([var.name.startswith(prefix) for prefix in prefixes_to_train]):\n",
"    to_fine_tune.append(var)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Export & run with TensorFlow Lite\n",
"\n",
"## Model Conversion\n",
"\n",
"First, we invoke the `export_tflite_graph_tf2.py` script to generate a TFLite-friendly intermediate SavedModel. This will then be passed to the TensorFlow Lite Converter for generating the final model.\n",
"\n",
"To know more about this process, please look at [this documentation](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_on_mobile_tf2.md)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This model can be run on-device with **TensorFlow Lite**. Look at [our SSD model signature](https://www.tensorflow.org/lite/models/object_detection/overview#uses_and_limitations) to understand how to interpret the model's input and output tensors. Our [Object Detection example](https://github.com/tensorflow/examples/tree/master/lite/examples/object_detection) is a good starting point for integrating the model into your mobile app.\n",
"\n",
"Refer to TFLite's [inference documentation](https://www.tensorflow.org/lite/guide/inference) for more details."
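,
"\n",
"As an illustrative sketch (not the exact code from this colab), the converted `model.tflite` can be invoked with the TFLite interpreter roughly as follows, where `preprocessed_image` is an assumed `1x320x320x3` float input:\n",
"\n",
"```python\n",
"interpreter = tf.lite.Interpreter(model_path='model.tflite')\n",
"interpreter.allocate_tensors()\n",
"input_details = interpreter.get_input_details()\n",
"output_details = interpreter.get_output_details()\n",
"interpreter.set_tensor(input_details[0]['index'], preprocessed_image)\n",
"interpreter.invoke()\n",
"# Output tensor ordering follows the SSD signature linked above.\n",
"boxes = interpreter.get_tensor(output_details[0]['index'])\n",
"```"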