"Welcome to the **CenterNet on-device with TensorFlow Lite** Colab. Here, we demonstrate how you can run a mobile-optimized version of the [CenterNet](https://arxiv.org/abs/1904.08189) architecture with [TensorFlow Lite](https://www.tensorflow.org/lite) (a.k.a. TFLite). \r\n",
"\r\n",
"Users can use this notebook as a reference for obtaining TFLite version of CenterNet for *Object Detection* or [*Keypoint detection*](https://cocodataset.org/#keypoints-2020). The code also shows how to perform pre-/post-processing & inference with TFLite's Python API.\r\n",
"\r\n",
"**NOTE:** CenterNet support in TFLite is still experimental, and currently works with floating-point inference only."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3LQWTJ-BWzmW"
},
"source": [
"# Set Up"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gx84EpH7INPj"
},
"source": [
"## Libraries & Imports"
]
},
{
"cell_type": "code",
"metadata": {
"id": "EU_hXi7IW9QC"
},
"source": [
"!pip install tf-nightly"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "ZTU9_JcOZz-J"
},
"source": [
"import os\r\n",
"import pathlib\r\n",
"\r\n",
"# Clone the tensorflow models repository if it doesn't already exist\r\n",
"if \"models\" in pathlib.Path.cwd().parts:\r\n",
" while \"models\" in pathlib.Path.cwd().parts:\r\n",
"The `detect` function shown below describes how input and output tensors from CenterNet (obtained in subsequent sections) can be processed. This logic can be ported to other languages depending on your application (for e.g. to Java for Android apps)."
"**NOTE:** Not all CenterNet models from the [TF2 Detection Zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md) work with TFLite, only the [MobileNet-based version](http://download.tensorflow.org/models/object_detection/tf2/20210210/centernet_mobilenetv2fpn_512x512_coco17_od.tar.gz) does.\r\n"
]
},
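{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is a minimal sketch of such a `detect` function using TFLite's Python `Interpreter`. The output-tensor ordering (and the keypoint outputs) shown here are assumptions; verify them against `interpreter.get_output_details()` for your converted model."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\r\n",
"import tensorflow as tf\r\n",
"\r\n",
"def detect(interpreter, input_tensor, include_keypoint=False):\r\n",
"  \"\"\"Runs detection on a preprocessed float32 image tensor of shape [1, H, W, 3].\r\n",
"\r\n",
"  `interpreter` must already have had `allocate_tensors()` called on it.\r\n",
"  \"\"\"\r\n",
"  input_details = interpreter.get_input_details()\r\n",
"  output_details = interpreter.get_output_details()\r\n",
"  interpreter.set_tensor(input_details[0]['index'], input_tensor)\r\n",
"  interpreter.invoke()\r\n",
"  # Assumed output ordering: boxes, classes, scores, num_detections (+ keypoints).\r\n",
"  boxes = interpreter.get_tensor(output_details[0]['index'])\r\n",
"  classes = interpreter.get_tensor(output_details[1]['index'])\r\n",
"  scores = interpreter.get_tensor(output_details[2]['index'])\r\n",
"  num_detections = interpreter.get_tensor(output_details[3]['index'])\r\n",
"  if include_keypoint:\r\n",
"    # Keypoints: normalized yx-coordinates of the 17 COCO human keypoints,\r\n",
"    # plus per-keypoint scores.\r\n",
"    kpts = interpreter.get_tensor(output_details[4]['index'])\r\n",
"    kpts_scores = interpreter.get_tensor(output_details[5]['index'])\r\n",
"    return boxes, classes, scores, num_detections, kpts, kpts_scores\r\n",
"  return boxes, classes, scores, num_detections\r\n",
"\r\n",
"# Example usage (model path & input size are assumptions):\r\n",
"# interpreter = tf.lite.Interpreter(model_path='model.tflite')\r\n",
"# interpreter.allocate_tensors()\r\n",
"# boxes, classes, scores, n = detect(\r\n",
"#     interpreter, np.zeros((1, 320, 320, 3), np.float32))"
],
"execution_count": null,
"outputs": []
},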
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Download Model from Detection Zoo\r\n",
"\r\n",
"**NOTE:** Not all CenterNet models from the [TF2 Detection Zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md) work with TFLite, only the [MobileNet-based version](http://download.tensorflow.org/models/object_detection/tf2/20210210/centernet_mobilenetv2fpn_512x512_coco17_od.tar.gz) does."
]
},
{
"cell_type": "code",
"metadata": {
"id": "Sywt8MKzIeOi"
},
"source": [
"# Get mobile-friendly CenterNet for Object Detection\r\n",
"# See TensorFlow 2 Detection Model Zoo for more details:\r\n",
"Now that we have downloaded the CenterNet model that uses MobileNet as a backbone, we can obtain a TensorFlow Lite model from it. \r\n",
"\r\n",
"The downloaded archive already contains `model.tflite` that works with TensorFlow Lite, but we re-generate the model in the next sub-section to account for cases where you might re-train the model on your own dataset (with corresponding changes to `pipeline.config` & `checkpoint` directory)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jT0bruuxM496"
},
"source": [
"## Generate TensorFlow Lite Model\r\n",
"\r\n",
"First, we invoke `export_tflite_graph_tf2.py` to generate a TFLite-friendly intermediate SavedModel. This will then be passed to the TensorFlow Lite Converter for generating the final model.\r\n",
"\r\n",
"This is similar to what we do for [SSD architectures](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_on_mobile_tf2.md)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "jpcCjiQ_JrU5",
"collapsed": true
},
"source": [
"%%bash\r\n",
"# Export the intermediate SavedModel that outputs 10 detections & takes in an \r\n",
"# image of dim 320x320.\r\n",
"# Modify these parameters according to your needs.\r\n",
"Unlike SSDs, CenterNet also supports COCO [Keypoint detection](https://cocodataset.org/#keypoints-2020). To be more specific, the 'keypoints' version of CenterNet shown here provides keypoints as a `[N, 17, 2]`-shaped tensor representing the (normalized) yx-coordinates of 17 COCO human keypoints.\r\n",
"\r\n",
"See the `detect()` function in the **Utilities for Inference** section to better understand the keypoints output."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xu47DkrDV18O"
},
"source": [
"## Download Model from Detection Zoo\r\n",
"\r\n",
"**NOTE:** Not all CenterNet models from the [TF2 Detection Zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md) work with TFLite, only the [MobileNet-based version](http://download.tensorflow.org/models/object_detection/tf2/20210210/centernet_mobilenetv2fpn_512x512_coco17_od.tar.gz) does."
]
},
{
"cell_type": "code",
"metadata": {
"id": "sd7f64WjWD7z"
},
"source": [
"# Get mobile-friendly CenterNet for Keypoint detection task.\r\n",
"# See TensorFlow 2 Detection Model Zoo for more details:\r\n",
"As before, we leverage `export_tflite_graph_tf2.py` to generate a TFLite-friendly intermediate SavedModel. This will then be passed to the TFLite converter to generating the final model.\r\n",
"\r\n",
"Note that we need to include an additional `keypoint_label_map_path` parameter for exporting the keypoints outputs."
]
},
{
"cell_type": "code",
"metadata": {
"id": "8kEhwYynX-cD"
},
"source": [
"%%bash\r\n",
"# Export the intermediate SavedModel that outputs 10 detections & takes in an \r\n",
"# image of dim 320x320.\r\n",
"# Modify these parameters according to your needs.\r\n",
"As mentioned earlier, both the above models can be run on mobile phones with TensorFlow Lite. See our [**inference documentation**](https://www.tensorflow.org/lite/guide/inference) for general guidelines on platform-specific APIs & leveraging hardware acceleration. Both the object-detection & keypoint-detection versions of CenterNet are compatible with our [GPU delegate](https://www.tensorflow.org/lite/performance/gpu). *We are working on developing quantized versions of this model.*\r\n",
"\r\n",
"To leverage *object-detection* in your Android app, the simplest way is to use TFLite's [**ObjectDetector Task API**](https://www.tensorflow.org/lite/inference_with_metadata/task_library/object_detector). It is a high-level API that encapsulates complex but common image processing and post processing logic. Inference can be done in 5 lines of code. It is supported in Java for Android and C++ for native code. *We are working on building the Swift API for iOS, as well as the support for the keypoint-detection model.*\r\n",
"\r\n",
"To use the Task API, the model needs to be packed with [TFLite Metadata](https://www.tensorflow.org/lite/convert/metadata). This metadata helps the inference code perform the correct pre & post processing as required by the model. Use the following code to create the metadata."
"See more information about *object-detection* models from our [public documentation](https://www.tensorflow.org/lite/examples/object_detection/overview). The [Object Detection example app](https://github.com/tensorflow/examples/tree/master/lite/examples/object_detection) is a good starting point for integrating that model into your Android and iOS app. You can find [examples](https://github.com/tensorflow/examples/tree/master/lite/examples/object_detection/android#switch-between-inference-solutions-task-library-vs-tflite-interpreter) of using both the TFLite Task Library and TFLite Interpreter API."