"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "bc8e6d24c21ed2939025a066095f3cab198547db"
Commit 72881871 authored by Khanh LeViet, committed by TF Object Detection Team

Add notebook and update ODT TFLite conversion doc to show how to make the model compatible with TFLite Task Library.

PiperOrigin-RevId: 378313314
parent 19738a07
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "RD3uxzaJweYr"
},
"source": [
"##### Copyright 2021 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "C-vBUz5IhJs8"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pHTibyMehTvH"
},
"source": [
"# Tutorial: Convert models trained using TensorFlow Object Detection API to TensorFlow Lite\n",
"\n",
"This tutorial demonstrate these steps:\n",
"* Convert TensorFlow models trained using the TensorFlow Object Detection API to [TensorFlow Lite](https://www.tensorflow.org/lite).\n",
"* Add the required metadata using [TFLite Metadata Writer API](https://www.tensorflow.org/lite/convert/metadata_writer_tutorial#object_detectors). This will make the TFLite model compatible with [TFLite Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/object_detector), so that the model can be integrated in mobile apps in 3 lines of code."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QIR1IFpnLJJA"
},
"source": [
"\u003ctable align=\"left\"\u003e\u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.sandbox.google.com/github/tensorflow/models/blob/master/research/object_detection/colab_tutorials/convert_odt_model_to_TFLite.ipynb\"\u003e\n",
" \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\n",
" \u003c/a\u003e\n",
"\u003c/td\u003e\u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/research/object_detection/colab_tutorials/convert_odt_model_to_TFLite.ipynb\"\u003e\n",
" \u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
"\u003c/td\u003e\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ok_Rpv7XNaFJ"
},
"source": [
"## Preparation"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t7CAW5C1cmel"
},
"source": [
"### Install the TFLite Support Library"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DwtFa0jSnNU4"
},
"outputs": [],
"source": [
"!pip install -q tflite_support"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XRfJR9QXctAR"
},
"source": [
"### Install the TensorFlow Object Detection API\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7PP2P5XAqeI5"
},
"outputs": [],
"source": [
"import os\n",
"import pathlib\n",
"\n",
"# Clone the tensorflow models repository if it doesn't already exist\n",
"if \"models\" in pathlib.Path.cwd().parts:\n",
" while \"models\" in pathlib.Path.cwd().parts:\n",
" os.chdir('..')\n",
"elif not pathlib.Path('models').exists():\n",
" !git clone --depth 1 https://github.com/tensorflow/models"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bP6SSh6zqi07"
},
"outputs": [],
"source": [
"%%bash\n",
"cd models/research/\n",
"protoc object_detection/protos/*.proto --python_out=.\n",
"cp object_detection/packages/tf2/setup.py .\n",
"pip install -q ."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i0to7aXKc0O9"
},
"source": [
"### Import the necessary libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4M8CC1PgqnSf"
},
"outputs": [],
"source": [
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import os\n",
"import random\n",
"import io\n",
"import imageio\n",
"import glob\n",
"import scipy.misc\n",
"import numpy as np\n",
"from six import BytesIO\n",
"from PIL import Image, ImageDraw, ImageFont\n",
"from IPython.display import display, Javascript\n",
"from IPython.display import Image as IPyImage\n",
"\n",
"import tensorflow as tf\n",
"\n",
"from object_detection.utils import label_map_util\n",
"from object_detection.utils import config_util\n",
"from object_detection.utils import visualization_utils as viz_utils\n",
"from object_detection.utils import colab_utils\n",
"from object_detection.utils import config_util\n",
"from object_detection.builders import model_builder\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "s9WIOOMTNti5"
},
"source": [
"## Download a pretrained model from Model Zoo\n",
"\n",
"In this tutorial, we demonstrate converting a pretrained model `SSD MobileNet V2 FPNLite 640x640` in the [TensorFlow 2 Model Zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md). You can replace the model with your own model and the rest will work the same."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TIY3cxDgsxuZ"
},
"outputs": [],
"source": [
"!wget http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8.tar.gz\n",
"!tar -xf ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8.tar.gz\n",
"!rm ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8.tar.gz"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0gV8vr6nN-z9"
},
"source": [
"## Generate TensorFlow Lite Model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Z8FjeSmmxpXz"
},
"source": [
"### Step 1: Export TFLite inference graph\n",
"\n",
"First, we invoke `export_tflite_graph_tf2.py` to generate a TFLite-friendly intermediate SavedModel. This will then be passed to the TensorFlow Lite Converter for generating the final model.\n",
"\n",
"Use `--help` with the above script to get the full list of supported parameters.\n",
"These can fine-tune accuracy and speed for your model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ChfN-tzBXqko"
},
"outputs": [],
"source": [
"!python models/research/object_detection/export_tflite_graph_tf2.py \\\n",
" --trained_checkpoint_dir {'ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8/checkpoint'} \\\n",
" --output_directory {'ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8/tflite'} \\\n",
" --pipeline_config_path {'ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8/pipeline.config'}"
]
},
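{
"cell_type": "markdown",
"metadata": {
"id": "verify_export_md"
},
"source": [
"Optionally, verify that the exporter produced a loadable intermediate `SavedModel` by listing its signatures before converting it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "verify_export_code"
},
"outputs": [],
"source": [
"# Quick check: load the TFLite-friendly SavedModel and list its signatures.\n",
"exported = tf.saved_model.load(\n",
"    'ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8/tflite/saved_model')\n",
"print(list(exported.signatures.keys()))"
]
},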
{
"cell_type": "markdown",
"metadata": {
"id": "IPr06cZ3OY3H"
},
"source": [
"### Step 2: Convert to TFLite\n",
"\n",
"Use the [TensorFlow Lite Converter](https://www.tensorflow.org/lite/convert) to\n",
"convert the `SavedModel` to TFLite. Note that you need to use `from_saved_model`\n",
"for TFLite conversion with the Python API.\n",
"\n",
"You can also leverage\n",
"[Post-training Quantization](https://www.tensorflow.org/lite/performance/post_training_quantization)\n",
"to\n",
"[optimize performance](https://www.tensorflow.org/lite/performance/model_optimization)\n",
"and obtain a smaller model. In this tutorial, we use the [dynamic range quantization](https://www.tensorflow.org/lite/performance/post_training_quant)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JMpy3Rlpq-Yq"
},
"outputs": [],
"source": [
"_TFLITE_MODEL_PATH = \"ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8/model.tflite\"\n",
"\n",
"converter = tf.lite.TFLiteConverter.from_saved_model('ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8/tflite/saved_model')\n",
"converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
"tflite_model = converter.convert()\n",
"\n",
"with open(_TFLITE_MODEL_PATH, 'wb') as f:\n",
" f.write(tflite_model)"
]
},
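{
"cell_type": "markdown",
"metadata": {
"id": "tflite_sanity_md"
},
"source": [
"Optionally, sanity-check the converted model by running one inference on random data with the `tf.lite.Interpreter`. With dynamic range quantization, the model keeps float32 inputs and outputs, so random values in the model's [-1, 1] input range are enough to confirm the graph executes end to end."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tflite_sanity_code"
},
"outputs": [],
"source": [
"interpreter = tf.lite.Interpreter(model_path=_TFLITE_MODEL_PATH)\n",
"interpreter.allocate_tensors()\n",
"input_details = interpreter.get_input_details()[0]\n",
"print('Input:', input_details['shape'], input_details['dtype'])\n",
"\n",
"# Run one inference on random data in the model's [-1, 1] input range.\n",
"dummy = np.random.uniform(-1, 1, input_details['shape']).astype(np.float32)\n",
"interpreter.set_tensor(input_details['index'], dummy)\n",
"interpreter.invoke()\n",
"for output in interpreter.get_output_details():\n",
"    print(output['name'], output['shape'])"
]
},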
{
"cell_type": "markdown",
"metadata": {
"id": "fyjlnmaEOtKp"
},
"source": [
"### Step 3: Add Metadata\n",
"\n",
"The model needs to be packed with [TFLite Metadata](https://www.tensorflow.org/lite/convert/metadata) to enable easy integration into mobile apps using the [TFLite Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/object_detector). This metadata helps the inference code perform the correct pre \u0026 post processing as required by the model. Use the following code to create the metadata."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-ecGLG_Ovjcr"
},
"outputs": [],
"source": [
"# Download the COCO dataset label map that was used to trained the SSD MobileNet V2 FPNLite 640x640 model\n",
"!wget https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_label_map.pbtxt -q\n",
"\n",
"# We need to convert the Object Detection API's labelmap into what the Task API needs:\n",
"# a txt file with one class name on each line from index 0 to N.\n",
"# The first '0' class indicates the background.\n",
"# This code assumes COCO detection which has 90 classes, you can write a label\n",
"# map file for your model if re-trained.\n",
"_ODT_LABEL_MAP_PATH = 'mscoco_label_map.pbtxt'\n",
"_TFLITE_LABEL_PATH = \"ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8/tflite_label_map.txt\"\n",
"\n",
"category_index = label_map_util.create_category_index_from_labelmap(\n",
" _ODT_LABEL_MAP_PATH)\n",
"f = open(_TFLITE_LABEL_PATH, 'w')\n",
"for class_id in range(1, 91):\n",
" if class_id not in category_index:\n",
" f.write('???\\n')\n",
" continue\n",
" name = category_index[class_id]['name']\n",
" f.write(name+'\\n')\n",
"f.close()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YJSyXq5Qss9X"
},
"source": [
"Then we'll add the label map and other necessary metadata (e.g. normalization config) to the TFLite model.\n",
"\n",
"As the `SSD MobileNet V2 FPNLite 640x640` model take input image with pixel value in the range of [-1..1] ([code](https://github.com/tensorflow/models/blob/b09e75828e2c65ead9e624a5c7afed8d214247aa/research/object_detection/models/ssd_mobilenet_v2_keras_feature_extractor.py#L132)), we need to set `norm_mean = 127.5` and `norm_std = 127.5`. See this [documentation](https://www.tensorflow.org/lite/convert/metadata#normalization_and_quantization_parameters) for more details on the normalization parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CRQpfDAWsPeK"
},
"outputs": [],
"source": [
"from tflite_support.metadata_writers import object_detector\n",
"from tflite_support.metadata_writers import writer_utils\n",
"\n",
"_TFLITE_MODEL_WITH_METADATA_PATH = \"ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8/model_with_metadata.tflite\"\n",
"\n",
"writer = object_detector.MetadataWriter.create_for_inference(\n",
" writer_utils.load_file(_TFLITE_MODEL_PATH), input_norm_mean=[127.5], \n",
" input_norm_std=[127.5], label_file_paths=[_TFLITE_LABEL_PATH])\n",
"writer_utils.save_file(writer.populate(), _TFLITE_MODEL_WITH_METADATA_PATH)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YFEAjRBdPCQb"
},
"source": [
"Optional: Print out the metadata added to the TFLite model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FT3-38PJsSOt"
},
"outputs": [],
"source": [
"from tflite_support import metadata\n",
"\n",
"displayer = metadata.MetadataDisplayer.with_model_file(_TFLITE_MODEL_WITH_METADATA_PATH)\n",
"print(\"Metadata populated:\")\n",
"print(displayer.get_metadata_json())\n",
"print(\"=============================\")\n",
"print(\"Associated file(s) populated:\")\n",
"print(displayer.get_packed_associated_file_list())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "l7zVslTRnEHX"
},
"source": [
"The TFLite model now can be integrated into a mobile app using the TFLite Task Library. See the [documentation](https://www.tensorflow.org/lite/inference_with_metadata/task_library/object_detector) for more details."
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Convert TF Object Detection API model to TFLite.ipynb",
"private_outputs": true,
"provenance": [
{
"file_id": "1R4_y-u14YTdvBzhmvC0HQwh3HkcCN2Bd",
"timestamp": 1623114733432
},
{
"file_id": "1Rey5kAzNQhJ77tsXGjhcAV0UZ6du0Sla",
"timestamp": 1622897882140
}
],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
@@ -13,17 +13,22 @@ on-device machine learning inference with low latency and a small binary size.
TensorFlow Lite uses many techniques for this, such as quantized kernels that
allow smaller and faster (fixed-point math) models.

This document shows how eligible models from the
[TF2 Detection zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md)
can be converted for inference with TFLite. For a runnable tutorial that walks
you through the steps explained in this document, see this Colab:

<a target="_blank" href="https://colab.research.google.com/github/tensorflow/models/blob/master/research/object_detection/colab_tutorials/convert_odt_model_to_TFLite.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run
in Google Colab</a>
For an end-to-end Python guide on how to fine-tune an SSD model for mobile
inference, look at
[this Colab](../colab_tutorials/eager_few_shot_od_training_tflite.ipynb).
**NOTE:** TFLite currently only supports **SSD Architectures** (excluding
EfficientDet) for boxes-based detection. Support for EfficientDet is provided
via the [TFLite Model Maker](https://www.tensorflow.org/lite/tutorials/model_maker_object_detection)
library.

The output model has the following inputs & outputs:
@@ -87,9 +92,46 @@ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8,
converter.representative_dataset = <...>
```
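For full-integer quantization, the `representative_dataset` placeholder above must be a generator that yields a few hundred preprocessed sample inputs drawn from the training distribution. A minimal sketch, assuming a local `calibration_images/` directory of JPEGs (the path is illustrative, not part of this repo) and this model's 640x640, [-1..1] float input:

```python
import tensorflow as tf

def representative_dataset():
    # Yield up to 100 calibration batches for the quantizer.
    for path in sorted(tf.io.gfile.glob('calibration_images/*.jpg'))[:100]:
        image = tf.io.decode_jpeg(tf.io.read_file(path), channels=3)
        image = tf.image.resize(image, (640, 640))
        # Normalize to the [-1, 1] range the model expects.
        image = (tf.cast(image, tf.float32) - 127.5) / 127.5
        yield [tf.expand_dims(image, 0)]

# Then hand it to the converter configured above:
# converter.representative_dataset = representative_dataset
```

The closer these samples are to real inference inputs, the better the quantizer can pick activation ranges.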
### Step 3: Add Metadata
The model needs to be packed with
[TFLite Metadata](https://www.tensorflow.org/lite/convert/metadata) to enable
easy integration into mobile apps using the
[TFLite Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/object_detector).
This metadata helps the inference code perform the correct pre & post processing
as required by the model. Use the following code to create the metadata.
```python
from tflite_support.metadata_writers import object_detector
from tflite_support.metadata_writers import writer_utils
writer = object_detector.MetadataWriter.create_for_inference(
writer_utils.load_file(_TFLITE_MODEL_PATH), input_norm_mean=[0],
input_norm_std=[255], label_file_paths=[_TFLITE_LABEL_PATH])
writer_utils.save_file(writer.populate(), _TFLITE_MODEL_WITH_METADATA_PATH)
```
See the TFLite Metadata Writer API [documentation](https://www.tensorflow.org/lite/convert/metadata_writer_tutorial#object_detectors)
for more details.
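The file passed as `_TFLITE_LABEL_PATH` is plain text with one class name per line, covering ids 1..N in order, with unused ids marked `???`. The Object Detection API's `label_map_util` is the usual way to produce it; purely as a dependency-free illustration, a small regex-based converter (a hypothetical helper, not part of the API) could look like:

```python
import re

def labelmap_to_txt(pbtxt: str, num_classes: int = 90) -> str:
    """Convert an Object Detection API label map (pbtxt text) into the
    one-class-name-per-line format expected by the metadata writer.
    Ids missing from the label map become '???'."""
    names = {}
    for block in re.findall(r'item\s*\{(.*?)\}', pbtxt, re.S):
        id_match = re.search(r'\bid:\s*(\d+)', block)
        # Prefer the human-readable display_name over the internal name.
        name_match = (re.search(r'display_name:\s*"([^"]+)"', block)
                      or re.search(r'\bname:\s*"([^"]+)"', block))
        if id_match and name_match:
            names[int(id_match.group(1))] = name_match.group(1)
    return '\n'.join(names.get(i, '???') for i in range(1, num_classes + 1))
```

For the COCO-trained model used here, `num_classes=90` matches its label map.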
## Running our model on Android

### Integrate the model into your app
You can use the TFLite Task Library's [ObjectDetector API](https://www.tensorflow.org/lite/inference_with_metadata/task_library/object_detector)
to integrate the model into your Android app.
```java
// Initialization
ObjectDetectorOptions options = ObjectDetectorOptions.builder().setMaxResults(1).build();
ObjectDetector objectDetector = ObjectDetector.createFromFileAndOptions(context, modelFile, options);
// Run inference
List<Detection> results = objectDetector.detect(image);
```
### Test the model using the TFLite sample app
To test our TensorFlow Lite model on device, we will use Android Studio to build
and run the TensorFlow Lite detection example with the new model. The example is
found in the
[TensorFlow examples repository](https://github.com/tensorflow/examples) under
@@ -102,7 +144,7 @@ that support API >= 21. Additional details are available on the
Next we need to point the app to our new detect.tflite file and give it the
names of our new labels. Specifically, we will copy our TensorFlow Lite
model with metadata to the app assets directory with the following command:

```shell
mkdir $TF_EXAMPLES/lite/examples/object_detection/android/app/src/main/assets
@@ -110,9 +152,6 @@ cp /tmp/tflite/detect.tflite \
  $TF_EXAMPLES/lite/examples/object_detection/android/app/src/main/assets
```
We will now edit the gradle build file to use these assets. First, open the
`build.gradle` file
`$TF_EXAMPLES/lite/examples/object_detection/android/app/build.gradle`. Comment
out the model download script to avoid your assets being overwritten:

```
// apply from:'download_model.gradle'
```

If your model is named `detect.tflite`, the example will use it automatically as
long as it has been properly copied into the base assets directory. If you need
to use a custom path or filename, open up the
$TF_EXAMPLES/lite/examples/object_detection/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java
file in a text editor and find the definition of TF_OD_API_MODEL_FILE. Update
this path to point to your new model file.

Once you’ve copied the TensorFlow Lite model and edited the gradle build script
to not use the downloaded assets, you can build and deploy the app using the