{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "tuOe1ymfHZPu" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "qFdPvlXBOdUN" }, "source": [ "# Image classification with Model Garden" ] }, { "cell_type": "markdown", "metadata": { "id": "MfBg1C5NB3X0" }, "source": [ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", " \u003ctd\u003e\n", " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/vision/image_classification\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", " \u003c/td\u003e\n", " \u003ctd\u003e\n", " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/vision/image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", " \u003c/td\u003e\n", " \u003ctd\u003e\n", " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/vision/image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n", " \u003c/td\u003e\n", " \u003ctd\u003e\n", " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/vision/image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n", " \u003c/td\u003e\n", "\u003c/table\u003e" ] }, { "cell_type": "markdown", "metadata": { "id": "Ta_nFXaVAqLD" }, "source": [ "This tutorial fine-tunes a Residual Network (ResNet) from the TensorFlow [Model Garden](https://github.com/tensorflow/models) package (`tensorflow-models`) to classify images in the [CIFAR](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.\n", "\n", "Model Garden contains a collection of state-of-the-art vision models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users to take full advantage of TensorFlow for their research and product development.\n", "\n", "This tutorial uses a [ResNet](https://arxiv.org/pdf/1512.03385.pdf) model, a state-of-the-art image classifier. This tutorial uses the ResNet-18 model, a convolutional neural network with 18 layers.\n", "\n", "This tutorial demonstrates how to:\n", "1. Use models from the TensorFlow Models package.\n", "2. Fine-tune a pre-built ResNet for image classification.\n", "3. Export the tuned ResNet model." ] }, { "cell_type": "markdown", "metadata": { "id": "G2FlaQcEPOER" }, "source": [ "## Setup\n", "\n", "Install and import the necessary modules." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XvWfdCrvrV5W" }, "outputs": [], "source": [ "!pip install -U -q \"tf-models-official\"" ] }, { "cell_type": "markdown", "metadata": { "id": "CKYMTPjOE400" }, "source": [ "Import TensorFlow, TensorFlow Datasets, and a few helper libraries." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Wlon1uoIowmZ" }, "outputs": [], "source": [ "import pprint\n", "import tempfile\n", "\n", "from IPython import display\n", "import matplotlib.pyplot as plt\n", "\n", "import tensorflow as tf\n", "import tensorflow_datasets as tfds" ] }, { "cell_type": "markdown", "metadata": { "id": "AVTs0jDd1b24" }, "source": [ "The `tensorflow_models` package contains the ResNet vision model, and the `official.vision.serving` model contains the function to save and export the tuned model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NHT1iiIiBzlC" }, "outputs": [], "source": [ "import tensorflow_models as tfm\n", "\n", "# These are not in the tfm public API for v2.9. They will be available in v2.10\n", "from official.vision.serving import export_saved_model_lib\n", "import official.core.train_lib" ] }, { "cell_type": "markdown", "metadata": { "id": "aKv3wdqkQ8FU" }, "source": [ "## Configure the ResNet-18 model for the Cifar-10 dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "5iN8mHEJjKYE" }, "source": [ "The CIFAR10 dataset contains 60,000 color images in mutually exclusive 10 classes, with 6,000 images in each class.\n", "\n", "In Model Garden, the collections of parameters that define a model are called *configs*. Model Garden can create a config based on a known set of parameters via a [factory](https://en.wikipedia.org/wiki/Factory_method_pattern).\n", "\n", "Use the `resnet_imagenet` factory configuration, as defined by `tfm.vision.configs.image_classification.image_classification_imagenet`. The configuration is set up to train ResNet to converge on [ImageNet](https://www.image-net.org/)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1M77f88Dj2Td" }, "outputs": [], "source": [ "exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')\n", "tfds_name = 'cifar10'\n", "ds,ds_info = tfds.load(\n", "tfds_name,\n", "with_info=True)\n", "ds_info" ] }, { "cell_type": "markdown", "metadata": { "id": "U6PVwXA-j3E7" }, "source": [ "Adjust the model and dataset configurations so that it works with Cifar-10 (`cifar10`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YWI7faVStQaV" }, "outputs": [], "source": [ "# Configure model\n", "exp_config.task.model.num_classes = 10\n", "exp_config.task.model.input_size = list(ds_info.features[\"image\"].shape)\n", "exp_config.task.model.backbone.resnet.model_id = 18\n", "\n", "# Configure training and testing data\n", "batch_size = 128\n", "\n", "exp_config.task.train_data.input_path = ''\n", "exp_config.task.train_data.tfds_name = tfds_name\n", "exp_config.task.train_data.tfds_split = 'train'\n", "exp_config.task.train_data.global_batch_size = batch_size\n", "\n", "exp_config.task.validation_data.input_path = ''\n", "exp_config.task.validation_data.tfds_name = tfds_name\n", "exp_config.task.validation_data.tfds_split = 'test'\n", "exp_config.task.validation_data.global_batch_size = batch_size\n" ] }, { "cell_type": "markdown", "metadata": { "id": "DE3ggKzzTD56" }, "source": [ "Adjust the trainer configuration." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "inE_-4UGkLud" }, "outputs": [], "source": [ "logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n", "\n", "if 'GPU' in ''.join(logical_device_names):\n", " print('This may be broken in Colab.')\n", " device = 'GPU'\n", "elif 'TPU' in ''.join(logical_device_names):\n", " print('This may be broken in Colab.')\n", " device = 'TPU'\n", "else:\n", " print('Running on CPU is slow, so only train for a few steps.')\n", " device = 'CPU'\n", "\n", "if device=='CPU':\n", " train_steps = 20\n", " exp_config.trainer.steps_per_loop = 5\n", "else:\n", " train_steps=5000\n", " exp_config.trainer.steps_per_loop = 100\n", "\n", "exp_config.trainer.summary_interval = 100\n", "exp_config.trainer.checkpoint_interval = train_steps\n", "exp_config.trainer.validation_interval = 1000\n", "exp_config.trainer.validation_steps = ds_info.splits['test'].num_examples // batch_size\n", "exp_config.trainer.train_steps = train_steps\n", "exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'\n", "exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps\n", "exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1\n", "exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100" ] }, { "cell_type": "markdown", "metadata": { "id": "5mTcDnBiTOYD" }, "source": [ "Print the modified configuration." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tuVfxSBCTK-y" }, "outputs": [], "source": [ "pprint.pprint(exp_config.as_dict())\n", "\n", "display.Javascript(\"google.colab.output.setIframeHeight('300px');\")" ] }, { "cell_type": "markdown", "metadata": { "id": "w7_X0UHaRF2m" }, "source": [ "Set up the distribution strategy." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ykL14FIbTaSt" }, "outputs": [], "source": [ "logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n", "\n", "if exp_config.runtime.mixed_precision_dtype == tf.float16:\n", " tf.keras.mixed_precision.set_global_policy('mixed_float16')\n", "\n", "if 'GPU' in ''.join(logical_device_names):\n", " distribution_strategy = tf.distribute.MirroredStrategy()\n", "elif 'TPU' in ''.join(logical_device_names):\n", " tf.tpu.experimental.initialize_tpu_system()\n", " tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')\n", " distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)\n", "else:\n", " print('Warning: this will be really slow.')\n", " distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])" ] }, { "cell_type": "markdown", "metadata": { "id": "W4k5YH5pTjaK" }, "source": [ "Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.\n", "\n", "The `Task` object has all the methods necessary for building the dataset, building the model, and running training \u0026 evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6MgYSH0PtUaW" }, "outputs": [], "source": [ "with distribution_strategy.scope():\n", " model_dir = tempfile.mkdtemp()\n", " task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)\n", "\n", "# tf.keras.utils.plot_model(task.build_model(), show_shapes=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IFXEZYdzBKoX" }, "outputs": [], "source": [ "for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n", " print()\n", " print(f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')\n", " print(f'labels.shape: {str(labels.shape):16} labels.dtype: {labels.dtype!r}')" ] }, { "cell_type": "markdown", "metadata": { "id": "yrwxnGDaRU0U" }, "source": [ "## Visualize the training data" ] }, { "cell_type": "markdown", "metadata": { "id": "683c255c6c52" }, "source": [ "The dataloader applies a z-score normalization using \n", "`preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`, so the images returned by the dataset can't be directly displayed by standard tools. The visualization code needs to rescale the data into the [0,1] range." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PdmOz2EC0Nx2" }, "outputs": [], "source": [ "plt.hist(images.numpy().flatten());" ] }, { "cell_type": "markdown", "metadata": { "id": "7a8582ebde7b" }, "source": [ "Use `ds_info` (which is an instance of `tfds.core.DatasetInfo`) to lookup the text descriptions of each class ID." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Wq4Wq_CuDG3Q" }, "outputs": [], "source": [ "label_info = ds_info.features['label']\n", "label_info.int2str(1)" ] }, { "cell_type": "markdown", "metadata": { "id": "8c652a6fdbcf" }, "source": [ "Visualize a batch of the data." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZKfTxytf1l0d" }, "outputs": [], "source": [ "def show_batch(images, labels, predictions=None):\n", " plt.figure(figsize=(10, 10))\n", " min = images.numpy().min()\n", " max = images.numpy().max()\n", " delta = max - min\n", "\n", " for i in range(12):\n", " plt.subplot(6, 6, i + 1)\n", " plt.imshow((images[i]-min) / delta)\n", " if predictions is None:\n", " plt.title(label_info.int2str(labels[i]))\n", " else:\n", " if labels[i] == predictions[i]:\n", " color = 'g'\n", " else:\n", " color = 'r'\n", " plt.title(label_info.int2str(predictions[i]), color=color)\n", " plt.axis(\"off\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xkA5h_RBtYYU" }, "outputs": [], "source": [ "plt.figure(figsize=(10, 10))\n", "for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n", " show_batch(images, labels)" ] }, { "cell_type": "markdown", "metadata": { "id": "v_A9VnL2RbXP" }, "source": [ "## Visualize the testing data" ] }, { "cell_type": "markdown", "metadata": { "id": "AXovuumW_I2z" }, "source": [ "Visualize a batch of images from the validation dataset." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Ma-_Eb-nte9A" }, "outputs": [], "source": [ "plt.figure(figsize=(10, 10));\n", "for images, labels in task.build_inputs(exp_config.task.validation_data).take(1):\n", " show_batch(images, labels)" ] }, { "cell_type": "markdown", "metadata": { "id": "ihKJt2FHRi2N" }, "source": [ "## Train and evaluate" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0AFMNvYxtjXx" }, "outputs": [], "source": [ "model, eval_logs = tfm.core.train_lib.run_experiment(\n", " distribution_strategy=distribution_strategy,\n", " task=task,\n", " mode='train_and_eval',\n", " params=exp_config,\n", " model_dir=model_dir,\n", " run_post_eval=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gCcHMQYhozmA" }, "outputs": [], "source": [ "# tf.keras.utils.plot_model(model, show_shapes=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "L7nVfxlBA8Gb" }, "source": [ "Print the `accuracy`, `top_5_accuracy`, and `validation_loss` evaluation metrics." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0124f938a1b9" }, "outputs": [], "source": [ "for key, value in eval_logs.items():\n", " if isinstance(value, tf.Tensor):\n", " value = value.numpy()\n", " print(f'{key:20}: {value:.3f}')" ] }, { "cell_type": "markdown", "metadata": { "id": "TDys5bZ1zsml" }, "source": [ "Run a batch of the processed training data through the model, and view the results" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GhI7zR-Uz1JT" }, "outputs": [], "source": [ "for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n", " predictions = model.predict(images)\n", " predictions = tf.argmax(predictions, axis=-1)\n", "\n", "show_batch(images, labels, tf.cast(predictions, tf.int32))\n", "\n", "if device=='CPU':\n", " plt.suptitle('The model was only trained for a few steps, it is not expected to do well.')" ] }, { "cell_type": "markdown", "metadata": { "id": "fkE9locGTBgt" }, "source": [ "## Export a SavedModel" ] }, { "cell_type": "markdown", "metadata": { "id": "9669d08c91af" }, "source": [ "The `keras.Model` object returned by `train_lib.run_experiment` expects the data to be normalized by the dataset loader using the same mean and variance statiscics in `preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`. This export function handles those details, so you can pass `tf.uint8` images and get the correct results.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AQCFa7BvtmDg" }, "outputs": [], "source": [ "# Saving and exporting the trained model\n", "export_saved_model_lib.export_inference_graph(\n", " input_type='image_tensor',\n", " batch_size=1,\n", " input_image_size=[32, 32],\n", " params=exp_config,\n", " checkpoint_path=tf.train.latest_checkpoint(model_dir),\n", " export_dir='./export/')" ] }, { "cell_type": "markdown", "metadata": { "id": "vVr6DxNqTyLZ" }, "source": [ "Test the exported model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gP7nOvrftsB0" }, "outputs": [], "source": [ "# Importing SavedModel\n", "imported = tf.saved_model.load('./export/')\n", "model_fn = imported.signatures['serving_default']" ] }, { "cell_type": "markdown", "metadata": { "id": "GiOp2WVIUNUZ" }, "source": [ "Visualize the predictions." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BTRMrZQAN4mk" }, "outputs": [], "source": [ "plt.figure(figsize=(10, 10))\n", "for data in tfds.load('cifar10', split='test').batch(12).take(1):\n", " predictions = []\n", " for image in data['image']:\n", " index = tf.argmax(model_fn(image[tf.newaxis, ...])['logits'], axis=1)[0]\n", " predictions.append(index)\n", " show_batch(data['image'], data['label'], predictions)\n", "\n", " if device=='CPU':\n", " plt.suptitle('The model was only trained for a few steps, it is not expected to do better than random.')" ] } ], "metadata": { "colab": { "name": "classification_with_model_garden.ipynb", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }